GitOrigin-RevId: 6a9d5beba2
tags/v1.3.0
@@ -11,6 +11,7 @@ from typing import Iterable, Union
import numpy as np
+from .._imperative_rt import VarNode
from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device
from ..ops import builtin
from ..ops.special import Const
@@ -59,7 +60,7 @@ def astype(x, dtype):
def convert_single_value(v, *, dtype=None, device=None):
-if isinstance(v, Tensor):
+if isinstance(v, (Tensor, VarNode)):
if not is_quantize(v.dtype):
v = astype(v, dtype)
else:
@@ -12,11 +12,12 @@ import functools
import numpy as np
from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.graph import VarNode
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import utils
from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import isscalar, setscalar
+from ..core.tensor.utils import astype, isscalar, setscalar
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor
@@ -77,7 +78,7 @@ __all__ = [
def _elwise(*args, mode):
-tensor_args = list(filter(lambda x: isinstance(x, Tensor), args))
+tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args))
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
@@ -109,7 +110,7 @@ def _elwise(*args, mode):
Elemwise.Mode.ROUND,
) and np.issubdtype(args[0].dtype, np.integer):
return args[0]
-args = tuple(map(lambda x: x.astype("float32"), args))
+args = tuple(map(lambda x: astype(x, "float32"), args))
return _elwise_apply(args, mode)
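For context, a hedged sketch of what this enables (mirroring the tests added later in this change; the model path and the var names "a"/"b" are illustrative): functional elemwise ops now accept graph VarNodes, not only Tensors.

import megengine.functional as F
from megengine.utils.network import Network as Net

net = Net.load("model.mge")                       # illustrative path
vara = net.var_filter.name("a").as_unique()
varb = net.var_filter.name("b").as_unique()
out = F.mul(vara.var, varb.var)                   # VarNode in, VarNode out
out = F.relu(out)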
@@ -65,7 +65,6 @@ def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
"""
Gets the inputs of owner opr of a variable.
"""
-assert isinstance(var, VarNode)
return var.owner.inputs
@@ -74,7 +73,6 @@ def get_owner_opr_type(var: VarNode) -> str:
Gets the type of owner opr of a variable.
"""
-assert isinstance(var, VarNode)
return var.owner.type
@@ -109,7 +107,7 @@ def graph_traversal(outputs: VarNode):
var2oprs = collections.defaultdict(list)
opr2receivers = collections.defaultdict(list)
-queue = list(map(lambda x: x.owner, outputs))
+queue = list(set(map(lambda x: x.owner, outputs)))
visited = set(map(lambda x: x.id, queue))
# iterate through whole comp_graph, fill in meta information
@@ -143,12 +141,15 @@ def graph_traversal(outputs: VarNode):
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
-def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
+def get_oprs_seq(
+outputs: List[VarNode], prune_reshape=False, prune_immtensor=True
+) -> List[OperatorNode]:
"""
Gets oprs in some topological order for a dumped model.
:param outputs: model outputs.
-:param prune_reshape: whether to prune the useless operators during inference.
+:param prune_reshape: whether to prune the useless operators used by Reshape opr during inference.
+:param prune_immtensor: whether to prune the ImmutableTensor opr.
:return: opr list with some correct execution order.
"""
@@ -160,9 +161,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
opr_id = indegree2opr[0].pop()
opr = map_oprs[opr_id]
nr_remain -= 1
-# skip const value generation operator
-if get_opr_type(opr) != "ImmutableTensor":
+if opr.type != "ImmutableTensor" or not prune_immtensor:
oprs_seq.append(opr)
for post_id in opr2receivers[opr_id]:
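A hedged sketch of the extended signature (here ``outputs`` stands for VarNodes of a loaded model; keeping ImmutableTensor oprs is what the new ``Network`` class below relies on):

from megengine.utils.comp_graph_tools import get_oprs_seq

# keep ImmutableTensor oprs so parameters stay visible in the returned sequence
oprs = get_oprs_seq(outputs, prune_reshape=False, prune_immtensor=False)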
@@ -0,0 +1,682 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections | |||||
import fnmatch | |||||
import itertools | |||||
import re | |||||
from collections import OrderedDict | |||||
from typing import Dict, List | |||||
import numpy as np | |||||
from ..core._imperative_rt import ComputingGraph | |||||
from ..core.tensor import megbrain_graph as G | |||||
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | |||||
from .network_node import ( | |||||
NetworkNode, | |||||
Host2DeviceCopy, | |||||
ImmutableTensor, | |||||
OpNode, | |||||
VarNode, | |||||
str_to_mge_class, | |||||
) | |||||
class Network: | |||||
def __init__(self): | |||||
self.input_vars = [] # input var of graph | |||||
self._orig_inputs = [] | |||||
self.output_vars = [] # output var of graph | |||||
self._orig_outputs = [] | |||||
self.all_oprs_map = OrderedDict() | |||||
self.all_vars_map = OrderedDict() | |||||
self.graph = ComputingGraph() | |||||
@classmethod | |||||
def load(cls, model_path: str, outspec: List[str] = None): | |||||
""" | |||||
Loads a computing graph as a Network object. | |||||
:param model_path: file path of mge model. | |||||
:param outspec: only load the subgraph with outspec as its endpoints. | |||||
""" | |||||
self = cls() | |||||
_, _, outputs = G.load_graph(model_path) | |||||
if outspec is not None: | |||||
output_spec = outspec.copy() | |||||
all_vars = get_dep_vars(outputs) + outputs | |||||
new_outputs = {} | |||||
for i in all_vars: | |||||
if i.name in output_spec: | |||||
new_outputs[i.name] = i | |||||
output_spec.remove(i.name) | |||||
assert len(output_spec) == 0, "Can not find {} in this model".format( | |||||
output_spec | |||||
) | |||||
outputs = [new_outputs[i] for i in outspec] | |||||
self._orig_outputs = outputs | |||||
self.add_dep_oprs(*outputs) | |||||
for x in self._orig_inputs: | |||||
self.input_vars.append(self._get_var(x)) | |||||
for x in self._orig_outputs: | |||||
self.output_vars.append(self._get_var(x)) | |||||
self.graph = self._orig_outputs[0].graph | |||||
return self | |||||
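A hedged usage sketch (the file path and the var name in ``outspec`` are illustrative):

from megengine.utils.network import Network as Net

net = Net.load("model.mge")                        # whole graph
sub = Net.load("model.mge", outspec=["feature"])   # subgraph ending at "feature"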
def _compile(self): | |||||
self.all_oprs_map = {} | |||||
self.all_vars_map = {} | |||||
for opr in self.all_oprs: | |||||
if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)): | |||||
opr.compile(self.graph) | |||||
else: | |||||
opr.compile() | |||||
if opr.name is not None: | |||||
opr._opr.name = opr.name | |||||
self.all_oprs_map[opr._opr.id] = opr | |||||
for o in opr.outputs: | |||||
self.all_vars_map[o.var.id] = o | |||||
def dump( | |||||
self, | |||||
file, | |||||
*, | |||||
keep_var_name: int = 1, | |||||
keep_opr_name: bool = False, | |||||
keep_param_name: bool = False, | |||||
keep_opr_priority: bool = False, | |||||
strip_info_file=None, | |||||
append_json=False, | |||||
optimize_for_inference=True, | |||||
append=False, | |||||
**kwargs | |||||
): | |||||
""" | |||||
Serializes graph to file. | |||||
:param file: output file, could be file object or filename. | |||||
:param append: whether output is appended to ``file``. | |||||
Only works when ``file`` is str. | |||||
:param keep_var_name: level for keeping variable names:
* 0: none of the names are kept
* 1: (default) keep names of output vars
* 2: keep names of all (output and internal) vars
:param keep_opr_name: whether to keep operator names.
:param keep_param_name: whether to keep param names, so param values can be
easily manipulated after loading the model.
:param keep_opr_priority: whether to keep priority settings for operators.
:param strip_info_file: a string for a path or a file handler. If not None,
the dump information for code strip would be written to ``strip_info_file``.
:param append_json: only checked when ``strip_info_file`` is not None. If True,
the information for code strip will be appended to ``strip_info_file``;
if False, ``strip_info_file`` will be overwritten.
:param optimize_for_inference: enable optimizations;
all optimize options will be skipped if this is False. Default: True
:Keyword Arguments: | |||||
* enable_io16xc32 -- | |||||
whether to use float16 for I/O between oprs and use | |||||
float32 as internal computation precision. Note the output var would be | |||||
changed to float16. | |||||
* enable_ioc16 -- | |||||
whether to use float16 for both I/O and computation | |||||
precision. | |||||
* enable_hwcd4 -- | |||||
whether to use NHWCD4 data layout. This is faster on some | |||||
OpenCL backend. | |||||
* enable_nchw88 -- | |||||
whether to use NCHW88 data layout, currently | |||||
used in X86 AVX backend. | |||||
* enable_nchw44 -- | |||||
whether to use NCHW44 data layout, currently | |||||
used in arm backend. | |||||
* enable_nchw44_dot -- | |||||
whether to use NCHW44_dot data layout, currently | |||||
used in armv8.2+dotprod backend. | |||||
* enable_nchw4 -- | |||||
whether to use NCHW4 data layout, currently | |||||
used in nvidia backend(based on cudnn). | |||||
* enable_nchw32 -- | |||||
whether to use NCHW32 data layout, currently | |||||
used in nvidia backend with tensorcore(based on cudnn). | |||||
* enable_chwn4 -- | |||||
whether to use CHWN4 data layout, currently | |||||
used in nvidia backend with tensorcore. | |||||
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||||
into one opr. | |||||
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||||
input for inference on nvidia backend(this optimization pass will | |||||
result in mismatch of the precision of output of training and | |||||
inference) | |||||
""" | |||||
self._compile() | |||||
out = [G.VarNode(var.var) for var in self.output_vars] | |||||
if optimize_for_inference: | |||||
out = G.optimize_for_inference(out, **kwargs) | |||||
dump_content, _ = G.dump_graph( | |||||
out, | |||||
keep_var_name=keep_var_name, | |||||
keep_opr_name=keep_opr_name, | |||||
keep_param_name=keep_param_name, | |||||
keep_opr_priority=keep_opr_priority, | |||||
strip_info_file=strip_info_file, | |||||
append_json=append_json, | |||||
) | |||||
if isinstance(file, str): | |||||
permission = "wb" if append == False else "ab" | |||||
file = open(file, permission) | |||||
file.write(dump_content) | |||||
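A hedged sketch of re-dumping a loaded graph with an inference optimization enabled (mirrors test_optimize_for_inference below; file names are illustrative):

net = Net.load("model.mge")
net.dump("model_f16.mge", enable_io16xc32=True)   # float16 I/O, float32 compute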
def make_const(self, data, name=None, device=None): | |||||
"""Makes an ImmutableTensor OpNode to provide a parameter for the network. | |||||
""" | |||||
node = ImmutableTensor(data, name, device, self.graph) | |||||
node.compile(self.graph) | |||||
return node.outputs[0] | |||||
def make_input_node(self, shape, dtype, name=None, device=None): | |||||
"""Makes a Host2DeviceCopy OpNode to provide an input varnode for the network. | |||||
""" | |||||
node = Host2DeviceCopy(shape, dtype, name, device) | |||||
node.compile(self.graph) | |||||
return node.outputs[0] | |||||
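Hedged sketches of both helpers (shapes, names and values are illustrative; ``net`` is a loaded Network, compare test_make_const and test_add_input below):

import numpy as np

const_b = net.make_const(np.array([0.0, 0.0]), name="b")   # ImmutableTensor output var
inp_c = net.make_input_node((2,), np.int32, name="c")      # Host2DeviceCopy output var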
def add_output(self, *vars: VarNode): | |||||
"""Adds vars into the network output node list | |||||
""" | |||||
for var in vars: | |||||
if var not in self.output_vars: | |||||
self.output_vars.append(var) | |||||
def remove_output(self, *vars: VarNode): | |||||
"""Removes vars from the network output node list. | |||||
""" | |||||
for var in vars: | |||||
if var in self.output_vars: | |||||
self.output_vars.remove(var) | |||||
def add_dep_oprs(self, *vars): | |||||
"""Adds dependent opnodes and varnodes of vars into network | |||||
""" | |||||
oprs = get_oprs_seq(vars, False, False) | |||||
for mge_opr in oprs: | |||||
if get_opr_type(mge_opr) == "Host2DeviceCopy": | |||||
self._orig_inputs.extend(mge_opr.outputs) | |||||
opr = self._add_opr(mge_opr) | |||||
if opr is not None: | |||||
for x in mge_opr.inputs: | |||||
opr.add_inp_var(self._get_var(x)) | |||||
# set out var | |||||
for x in mge_opr.outputs: | |||||
opr.add_out_var(self._get_var(x)) | |||||
return [self.all_vars_map[var.id] for var in vars] | |||||
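A hedged sketch: build new graph vars with functional ops on ``.var`` and register their dependent oprs into the Network (based on test_add_output below; ``net`` and the var names are illustrative):

var_a = net.var_filter.name("a").as_unique()
var_b = net.var_filter.name("b").as_unique()
y = F.sigmoid(F.add(var_a.var, var_b.var))
new_var = net.add_dep_oprs(y)[0]    # Network VarNode wrapping y
new_var.name = "o1"
net.add_output(new_var)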
def modify_opr_names(self, modifier): | |||||
"""Modifies names of operators **inplace**; useful for merging loaded | |||||
network into another network | |||||
:param modifier: a string to be prepended to the name, or a function | |||||
that maps from name to name | |||||
:type modifier: str or callable | |||||
""" | |||||
if isinstance(modifier, str): | |||||
om = modifier | |||||
modifier = lambda v: "{}.{}".format(om, v) | |||||
assert isinstance(modifier, collections.Callable) | |||||
for i in self.all_oprs: | |||||
v0 = i.name | |||||
v1 = modifier(v0) | |||||
assert isinstance(v1, str) | |||||
i.name = v1 | |||||
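A hedged sketch (prefixes and the opr name are illustrative; see test_modify_opr_name below):

net.modify_opr_names("net")                  # "conv0" -> "net.conv0"
net.modify_opr_names(lambda s: "net1." + s)  # "net.conv0" -> "net1.net.conv0"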
def reset_batch_size(self, batchsize, *, blacklist=()): | |||||
"""Helper for reset batch size; first dimension of all data providers | |||||
not in blacklist are assumed to be the batch size | |||||
:param blacklist: data provider names whose first dimension is not | |||||
batchbatch size | |||||
""" | |||||
blacklist = set(blacklist) | |||||
prev_batchsize = None | |||||
for i in self.data_providers_filter: | |||||
if i.name in blacklist: | |||||
blacklist.remove(i.name) | |||||
else: | |||||
shp = list(i.shape) | |||||
if prev_batchsize is None: | |||||
prev_batchsize = shp[0] | |||||
else: | |||||
assert prev_batchsize == shp[0], ( | |||||
"batchsize mismatch: batchsize={} " | |||||
"shape={} dp={}".format(prev_batchsize, shp, i.name) | |||||
) | |||||
shp[0] = batchsize | |||||
i.shape = tuple(shp) | |||||
assert prev_batchsize is not None, "no data provider found" | |||||
assert not blacklist, "unused items in blacklist: {}".format(blacklist) | |||||
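A hedged sketch (paths are illustrative; based on test_reset_batchsize below):

net = Net.load("model.mge")
net.reset_batch_size(1)         # first dim of every data provider becomes 1
net.dump("model_bs1.mge", optimize_for_inference=False)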
def replace_vars(self, repl_dict: Dict[VarNode, VarNode]): | |||||
""" | |||||
Replaces vars in the graph. | |||||
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | |||||
""" | |||||
for var in self.all_vars: | |||||
if var in repl_dict: | |||||
repl_var = repl_dict[var] | |||||
owner = repl_var.owner | |||||
idx = owner.outputs.index(repl_var) | |||||
owner.outputs[idx] = var | |||||
var.__dict__.update(repl_var.__dict__) | |||||
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | |||||
""" | |||||
Replaces operators in the graph. | |||||
:param repl_dict: the map {old_opr: new_opr} that specifies how to replace the operators.
""" | |||||
for opr in self.all_oprs: | |||||
if opr in repl_dict: | |||||
assert len(opr.outputs) == len( | |||||
repl_dict[opr].outputs | |||||
), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) | |||||
repl_dict[opr].outputs = opr.outputs | |||||
for ind, var in enumerate(opr.outputs): | |||||
var.owner = repl_dict[opr] | |||||
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | |||||
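A hedged sketch of var replacement (based on test_replace_var below; ``net`` and the var names are illustrative):

vara = net.var_filter.name("a").as_unique()
varb = net.var_filter.name("b").as_unique()
new_out = net.add_dep_oprs(F.relu(F.mul(vara.var, varb.var)))[0]
opnode = list(net.opr_filter.has_input(vara))
net.replace_vars({opnode[0].outputs[0]: new_out})   # graph now computes relu(a * b) there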
def get_opr_by_type(self, oprcls, unique=True): | |||||
assert issubclass(oprcls, OpNode) | |||||
rst = self.opr_filter.type(oprcls).as_list() | |||||
if unique: | |||||
assert len(rst) == 1, "{} operators of type {} found".format( | |||||
len(rst), oprcls | |||||
) | |||||
(rst,) = rst | |||||
return rst | |||||
def get_opr_by_name(self, name, unique=True): | |||||
rst = self.opr_filter.name(name).as_list() | |||||
if unique: | |||||
assert len(rst) == 1, "{} operators of type {} found".format(len(rst), name) | |||||
(rst,) = rst | |||||
return rst | |||||
def get_var_by_name(self, name, unique=True): | |||||
rst = self.var_filter.name(name).as_list() | |||||
if unique: | |||||
assert len(rst) == 1, "{} operators of type {} found".format(len(rst), name) | |||||
(rst,) = rst | |||||
return rst | |||||
def get_var_receive_oprs(self, var): | |||||
""" Gets all oprs which use var as input | |||||
""" | |||||
return self.opr_filter.has_input(var).as_list() | |||||
def get_dep_oprs(self, var): | |||||
"""Gets dependent oprs of var | |||||
""" | |||||
return get_oprs_seq(var, False, False) | |||||
@property | |||||
def opr_filter(self): | |||||
"""Filter on all opnodes of the Network. | |||||
""" | |||||
oprs = self.all_oprs | |||||
return NodeFilter(itertools.islice(oprs, len(oprs))) | |||||
@property | |||||
def var_filter(self): | |||||
"""Filter on all varnode of the Network. | |||||
""" | |||||
vars = self.all_vars | |||||
return NodeFilter(itertools.islice(vars, len(vars))) | |||||
@property | |||||
def params_filter(self): # all immutable tensor | |||||
"""Filter on all parameters (ImmutableTensor Opr) of the Network | |||||
""" | |||||
return self.opr_filter.param_provider() | |||||
@property | |||||
def data_providers_filter(self): # all host2devicecopy | |||||
"""Filter on all input nodes (Host2DeviceCopy Opr) of the Network | |||||
""" | |||||
return self.opr_filter.data_provider() | |||||
@property | |||||
def dest_vars(self): | |||||
"""Output varnodes of the Network. | |||||
""" | |||||
return self.output_vars | |||||
@property | |||||
def all_oprs(self): | |||||
return get_oprs_seq(self.output_vars, False, False) | |||||
@property | |||||
def all_vars(self): | |||||
return get_dep_vars(self.output_vars) | |||||
@property | |||||
def all_vars_dict(self): | |||||
return self.var_filter.as_dict() | |||||
@property | |||||
def all_oprs_dict(self): | |||||
return self.opr_filter.as_dict() | |||||
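A hedged sketch of typical queries (based on test_query below; ``net`` is a loaded Network and ``N`` is ``megengine.utils.network_node``):

import megengine.utils.network_node as N

n_inputs = net.data_providers_filter.as_count()          # Host2DeviceCopy oprs
n_params = net.params_filter.as_count()                  # ImmutableTensor oprs
convs = net.opr_filter.type(N.ConvolutionForward).as_list()
data_var = net.var_filter.name("data").as_unique()
consumers = net.opr_filter.has_input(data_var).as_list()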
# used for loading and building graph | |||||
def _add_opr(self, x): | |||||
# TODO: use megbrain C++ RTTI to replace type string | |||||
if x.id not in self.all_oprs_map: | |||||
self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x) | |||||
return self.all_oprs_map[x.id] | |||||
else: | |||||
return None | |||||
def _get_opr(self, x): | |||||
if x.id in self.all_oprs_map: | |||||
return self.all_oprs_map[x.id] | |||||
else: | |||||
return None | |||||
def _get_var(self, x): | |||||
# auto convert to VarNode of Network | |||||
if x.id not in self.all_vars_map: | |||||
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | |||||
return self.all_vars_map[x.id] | |||||
def as_varnode(obj): | |||||
"""convert a :class:`.VarNode` compatible object to :class:`.VarNode`. | |||||
:param obj: it must be one of the following: | |||||
1. a :class:`.VarNode` object | |||||
2. a :class:`.OpNode` object that has unique output | |||||
3. an iterable that produces either type 1 or 2, with length 1 | |||||
:rtype: :class:`.VarNode` | |||||
""" | |||||
if type(obj) is VarNode: | |||||
return obj | |||||
if isinstance(obj, OpNode): | |||||
assert len(obj.outputs) == 1, ( | |||||
"operator {} must have one output to be converted to VarNode; " | |||||
"got {} actually".format(obj, len(obj.outputs)) | |||||
) | |||||
ret = obj.outputs[0] | |||||
assert type(ret) is VarNode | |||||
return ret | |||||
assert isinstance( | |||||
obj, collections.Iterable | |||||
), "{} is not compatible with VarNode".format(obj) | |||||
val = list(obj) | |||||
assert ( | |||||
len(val) == 1 | |||||
), "can not convert sequence of length {} to VarNode ({})".format( | |||||
len(val), (lambda s: s if len(s) < 50 else s[:50] + " ...")(str(val)) | |||||
) | |||||
return as_varnode(val[0]) | |||||
def as_oprnode(obj): | |||||
"""convert a :class:`.OpNode` compatible object to | |||||
:class:`.OpNode`; it works like :func:`as_varnode`.""" | |||||
if type(obj) is VarNode: | |||||
return obj.owner | |||||
if isinstance(obj, OpNode): | |||||
return obj | |||||
assert isinstance( | |||||
obj, collections.Iterable | |||||
), "{} is not compatible with OpNode".format(obj) | |||||
val = list(obj) | |||||
assert ( | |||||
len(val) == 1 | |||||
), "can not convert sequence of length {} to " "OpNode({})".format(len(val), val) | |||||
return as_oprnode(val[0]) | |||||
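A hedged sketch (based on test_replace_opr below; ``net``, ``vara`` and ``varb`` are as in the earlier sketches, and the sequence must contain exactly one element):

from megengine.utils.network import as_oprnode

var_list = net.add_dep_oprs(F.relu(F.sub(vara.var, varb.var)))
repl_opr = as_oprnode(var_list)    # single-var sequence -> its owner OpNode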
class NodeFilter: | |||||
"""Filter on node iterator. This class is an iterator of | |||||
:class:`.NetworkNode` objects and multiple filtering conditions and | |||||
mappers can be chained. | |||||
Example:: | |||||
# find all :class:`.ImmutableTensor` nodes | |||||
for i in NodeFilter(node_iter).param_provider(): | |||||
print(i) | |||||
# find all :class:`.ImmutableTensor` nodes that end with ':W' | |||||
for i in NodeFilter(node_iter).param_provider().name('*:W'): | |||||
print(i) | |||||
# number of inputs | |||||
nr_input = NodeFilter(node_iter).data_provider().as_count() | |||||
""" | |||||
_iter = None | |||||
def __init__(self, node_iter): | |||||
""" | |||||
:param node_iter: iterator to :class:`.NetworkNode`, or a
:class:`.VarNode`-compatible object; in the latter case, its
dependent oprs would be used
""" | |||||
if isinstance(node_iter, VarNode): | |||||
oprs = get_oprs_seq(node_iter, False, False) | |||||
node_iter = itertools.islice(oprs, len(oprs) - 1) | |||||
if isinstance(node_iter, OpNode): | |||||
oprs = get_oprs_seq(node_iter.inputs, False, False) | |||||
node_iter = itertools.islice(oprs, len(oprs) - 1) | |||||
assert isinstance(node_iter, collections.Iterable) | |||||
if (not isinstance(node_iter, NodeFilter)) and type( | |||||
self | |||||
) is not NodeFilterCheckType: | |||||
node_iter = NodeFilterCheckType(node_iter, NetworkNode) | |||||
self._iter = node_iter | |||||
@classmethod | |||||
def make_all_deps(cls, *dest_vars): | |||||
"""make a :class:`NodeFilter` that contains all deps of given vars""" | |||||
return cls(list(get_oprs_seq(dest_vars, False, False))) | |||||
def __iter__(self): | |||||
"""to be overwritten by subclass to implement filters""" | |||||
return iter(self._iter) | |||||
def type(self, node_type): | |||||
"""filter by specific node type | |||||
:param node_type: node type class | |||||
:return: a new :class:`NodeFilter` object | |||||
""" | |||||
return NodeFilterType(self, node_type) | |||||
def check_type(self, node_type):
"""assert that all oprs produced by this iterator are instances of a
certain type
:param node_type: node type class | |||||
:return: a new :class:`NodeFilter` object | |||||
:raises TypeError: if type check failed | |||||
""" | |||||
return NodeFilterCheckType(self, node_type) | |||||
def not_type(self, node_type): | |||||
"""remove oprs of specific type | |||||
:param node_type: node type class | |||||
:return: a new :class:`NodeFilter` object | |||||
""" | |||||
return NodeFilterNotType(self, node_type) | |||||
def param_provider(self): | |||||
"""get :class:`.ParamProvider` oprs; shorthand for | |||||
``.type(ParamProvider)``""" | |||||
return self.type(ImmutableTensor) | |||||
def data_provider(self): | |||||
"""get :class:`.DataProvider` oprs; shorthand for | |||||
``.type(DataProvider)``""" | |||||
return self.type(Host2DeviceCopy) | |||||
def name(self, pattern, ignorecase=True): | |||||
"""filter by node name | |||||
:param pattern: a string in glob syntax that can contain ``?`` to match a
single character and ``*`` to match arbitrary characters.
:type pattern: :class:`str`
:param ignorecase: whether to ignore case
:type ignorecase: bool
:return: a new :class:`NodeFilter` object | |||||
""" | |||||
return NodeFilterName(self, pattern, ignorecase) | |||||
def has_input(self, var): | |||||
"""an opr is kept if it has given var as one of its inputs | |||||
:param var: var node to checked | |||||
:return: a new :class:`NodeFilter` object | |||||
""" | |||||
return NodeFilterHasInput(self, var) | |||||
def as_list(self): | |||||
"""consume this iterator and return its content as a list | |||||
:rtype: [:class:`.GraphNodeBase`] | |||||
""" | |||||
return list(self) | |||||
def as_unique(self): | |||||
"""assert that this iterator yields only one node and return it | |||||
:return: the unique node | |||||
:rtype: :class:`.GraphNodeBase` | |||||
:raises ValueError: if this iterator does not yield a unique node | |||||
""" | |||||
(opr,) = self | |||||
return opr | |||||
def as_dict(self): | |||||
"""construct an ordered dict to map from node names to objects in | |||||
this iterator | |||||
:rtype: :class:`OrderedDict` | |||||
""" | |||||
return collections.OrderedDict((i.name, i) for i in self) | |||||
def as_count(self): | |||||
"""consume this iterator and get the number of elements | |||||
:rtype: int | |||||
""" | |||||
return sum(1 for _ in self) | |||||
class NodeFilterType(NodeFilter): | |||||
"""see :meth:`NodeFilter.type`""" | |||||
_node_type = None | |||||
def __init__(self, node_iter, node_type): | |||||
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format( | |||||
node_type | |||||
) | |||||
super().__init__(node_iter) | |||||
self._node_type = node_type | |||||
def __iter__(self): | |||||
for i in self._iter: | |||||
if isinstance(i, self._node_type): | |||||
yield i | |||||
class NodeFilterNotType(NodeFilterType): | |||||
"""see :meth:`NodeFilter.not_type`""" | |||||
def __iter__(self): | |||||
for i in self._iter: | |||||
if not isinstance(i, self._node_type): | |||||
yield i | |||||
class NodeFilterCheckType(NodeFilterType): | |||||
"""see :meth:`NodeFilter.check_type`""" | |||||
def __iter__(self): | |||||
for i in self._iter: | |||||
if not isinstance(i, self._node_type): | |||||
raise TypeError( | |||||
"all nodes should be {}; got {!r}".format(self._node_type, i) | |||||
) | |||||
yield i | |||||
class NodeFilterHasInput(NodeFilter): | |||||
"""see :meth:`NodeFilter.has_input`""" | |||||
_var = None | |||||
def __init__(self, node_iter, var): | |||||
var = as_varnode(var) | |||||
super().__init__(node_iter) | |||||
self.var = var | |||||
def __iter__(self): | |||||
for i in self._iter: | |||||
assert isinstance( | |||||
i, OpNode | |||||
), "has_input() must be used with OpNode; " "got {!r}".format(i) | |||||
if self.var in i.inputs: | |||||
yield i | |||||
class NodeFilterName(NodeFilter): | |||||
"""see :meth:`NodeFilter.name`""" | |||||
_re = None | |||||
def __init__(self, node_iter, pattern, ignorecase): | |||||
super().__init__(node_iter) | |||||
self._re = self.make_re(pattern, ignorecase) | |||||
@classmethod | |||||
def make_re(cls, pattern, ignorecase=True): | |||||
assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern) | |||||
assert isinstance(ignorecase, bool) | |||||
flags = 0 | |||||
if ignorecase: | |||||
flags |= re.IGNORECASE | |||||
return re.compile(fnmatch.translate(pattern), flags=flags) | |||||
def __iter__(self): | |||||
for i in self._iter: | |||||
if self._re.match(i.name): | |||||
yield i |
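A hedged sketch of the glob name filter (patterns are illustrative; compare test_query below and the NodeFilter docstring above):

n_bias = net.var_filter.name("*bias").as_count()    # vars whose name ends with "bias"
weights = net.params_filter.name("*:W").as_list()   # ImmutableTensors named "...:W"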
@@ -0,0 +1,628 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import json
import sys
from typing import Callable
import numpy as np
from ..core import _imperative_rt as rt
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.megbrain_graph import InputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars
class NetworkNode: | |||||
pass | |||||
class VarNode(NetworkNode): | |||||
def __init__(self, owner_opr=None, name=None): | |||||
self.var = None | |||||
self.owner = owner_opr | |||||
self.name = name | |||||
self.id = id(self) | |||||
@classmethod | |||||
def load(cls, sym_var, owner_opr): | |||||
obj = cls() | |||||
obj.var = sym_var # mgb varnode | |||||
obj.name = sym_var.name | |||||
obj.owner = owner_opr | |||||
return obj | |||||
@property | |||||
def shape(self): | |||||
rst = None | |||||
if self.var: | |||||
try: | |||||
rst = self.var.shape | |||||
except: | |||||
rst = None | |||||
return rst | |||||
@property | |||||
def dtype(self): | |||||
return self.var.dtype if self.var else None | |||||
def set_owner_opr(self, owner_opr):
self.owner = owner_opr
class OpNode(NetworkNode): | |||||
opdef = None | |||||
type = None | |||||
def __init__(self): | |||||
self.inputs = [] | |||||
self.outputs = [] | |||||
self.params = {} | |||||
self._opr = None # mgb opnode | |||||
self.id = id(self) | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = cls() | |||||
obj.params = json.loads(opr.params) | |||||
obj.name = opr.name | |||||
obj._opr = opr | |||||
return obj | |||||
def compile(self, graph=None): | |||||
op = self.opdef(**self.params) | |||||
args = [i.var for i in self.inputs] | |||||
outputs = rt.invoke_op(op, args) | |||||
assert len(outputs) == len(self.outputs) | |||||
self._opr = outputs[0].owner | |||||
for i in range(len(self.outputs)): | |||||
self.outputs[i].var = outputs[i] | |||||
self.outputs[i].var.name = self.outputs[i].name | |||||
assert self.outputs[i].owner is self | |||||
def add_inp_var(self, x): | |||||
self.inputs.append(x) | |||||
def add_out_var(self, x): | |||||
self.outputs.append(x) | |||||
def str_to_mge_class(classname): | |||||
# TODO: use megbrain C++ RTTI to replace type string | |||||
if classname == "RNGOpr<MegDNNOpr>": | |||||
classname = "RNGOpr" | |||||
oprcls = getattr(sys.modules[__name__], classname, None) | |||||
return oprcls if oprcls else ReadOnlyOpNode | |||||
class Host2DeviceCopy(OpNode): | |||||
type = "Host2DeviceCopy" | |||||
def __init__(self, shape=None, dtype=None, name=None, device=None): | |||||
super().__init__() | |||||
self.shape = shape | |||||
self.dtype = dtype | |||||
self.name = name | |||||
self.device = Device(device).to_c() if device else Device("xpux").to_c() | |||||
self.outputs = [] | |||||
@classmethod | |||||
def load(cls, opr): | |||||
self = cls() | |||||
self.outputs = [] | |||||
assert len(opr.outputs) == 1, "wrong number of outputs" | |||||
self.shape = opr.outputs[0].shape | |||||
self.dtype = opr.outputs[0].dtype | |||||
self.name = opr.outputs[0].name | |||||
self.device = opr.outputs[0].comp_node | |||||
self._opr = opr | |||||
return self | |||||
def compile(self, graph): | |||||
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||||
self._opr = outputs.owner | |||||
if len(self.outputs) == 0: | |||||
self.outputs.append(VarNode(self, self.name)) | |||||
self.outputs[0].var = outputs | |||||
assert self.outputs[0].owner is self | |||||
class ImmutableTensor(OpNode): | |||||
type = "ImmutableTensor" | |||||
def __init__(self, data=None, name=None, device=None, graph=None): | |||||
super().__init__() | |||||
self.name = name | |||||
self.outputs = [] | |||||
self.graph = graph | |||||
if data is not None: | |||||
self.set_value(data, device) | |||||
@property | |||||
def device(self): | |||||
return self._opr.outputs[0].comp_node if self._opr else None | |||||
@device.setter | |||||
def device(self, device): | |||||
self.set_value(self.numpy(), device) | |||||
@property | |||||
def shape(self): | |||||
return self.outputs[0].shape | |||||
@property | |||||
def dtype(self): | |||||
return self._opr.outputs[0].dtype if self._opr else None | |||||
def numpy(self): | |||||
return self._opr.outputs[0].value if self._opr else None | |||||
def set_value(self, data, device=None): | |||||
assert self.graph is not None | |||||
cn = device if device else self.device | |||||
assert isinstance(data, (int, float, np.ndarray)) | |||||
if isinstance(data, (int, float)): | |||||
data = np.array(data) | |||||
if data.dtype == np.float64: | |||||
data = data.astype(np.float32) | |||||
elif data.dtype == np.int64: | |||||
data = data.astype(np.int32) | |||||
varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | |||||
if len(self.outputs) == 0: | |||||
self.outputs.append(VarNode(self, self.name)) | |||||
self.outputs[0].var = varnode | |||||
self._opr = varnode.owner | |||||
@classmethod | |||||
def load(cls, opr): | |||||
self = cls() | |||||
self.outputs = [] | |||||
self._opr = opr | |||||
self.name = opr.outputs[0].name | |||||
self.graph = opr.graph | |||||
return self | |||||
def compile(self, graph): | |||||
assert self.outputs[0].var is self._opr.outputs[0] | |||||
assert self.outputs[0].owner is self | |||||
if self.graph != graph: | |||||
self.graph = graph | |||||
self.set_value(self.numpy()) | |||||
if self.name is not None: | |||||
self.outputs[0].var.name = self.name | |||||
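A hedged sketch of editing a parameter of a loaded graph in place (based on test_modify_params below; ``net`` is a loaded Network with a single parameter):

param = net.params_filter.as_unique()     # the only ImmutableTensor opr here
print(param.shape, param.dtype)
param.set_value(np.array([3.0], dtype=np.float32))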
class ReadOnlyOpNode(OpNode): | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(ReadOnlyOpNode, cls).load(opr) | |||||
obj.type = opr.type | |||||
return obj | |||||
def compile(self): | |||||
assert self._opr is not None | |||||
assert len(self.inputs) == len(self._opr.inputs) | |||||
assert len(self.outputs) == len(self._opr.outputs) | |||||
repl_dict = {} | |||||
for ind, i in enumerate(self.inputs): | |||||
if i.var != self._opr.inputs[ind]: | |||||
repl_dict[self._opr.inputs[ind]] = i.var | |||||
if bool(repl_dict): | |||||
out_vars = replace_vars(self._opr.outputs, repl_dict) | |||||
for ind, o in enumerate(self.outputs): | |||||
o.var = out_vars[ind] | |||||
class Elemwise(OpNode): | |||||
type = "Elemwise" | |||||
opdef = builtin.Elemwise | |||||
class Reduce(OpNode): | |||||
type = "Reduce" | |||||
opdef = builtin.Reduce | |||||
class TypeCvt(OpNode): | |||||
type = "TypeCvt" | |||||
opdef = builtin.TypeCvt | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(TypeCvt, cls).load(opr) | |||||
t_dtype = opr.outputs[0].dtype | |||||
obj.params["dtype"] = t_dtype | |||||
return obj | |||||
class MatrixInverse(OpNode): | |||||
type = "MatrixInverse" | |||||
opdef = builtin.MatrixInverse | |||||
class MatrixMul(OpNode): | |||||
type = "MatrixMul" | |||||
opdef = builtin.MatrixMul | |||||
class BatchedMatrixMul(OpNode): | |||||
type = "BatchedMatmul" | |||||
opdef = builtin.BatchedMatrixMul | |||||
class Dot(OpNode): | |||||
type = "Dot" | |||||
opdef = builtin.Dot | |||||
class SVD(OpNode): | |||||
type = "SVD" | |||||
opdef = builtin.SVD | |||||
class ConvolutionForward(OpNode): | |||||
type = "Convolution" | |||||
opdef = builtin.Convolution | |||||
class ConvolutionBackwardData(OpNode): | |||||
type = "ConvTranspose" | |||||
opdef = builtin.ConvolutionBackwardData | |||||
class DeformableConvForward(OpNode): | |||||
type = "DeformableConv" | |||||
opdef = builtin.DeformableConv | |||||
class GroupLocalForward(OpNode): | |||||
type = "GroupLocal" | |||||
opdef = builtin.GroupLocal | |||||
class PoolingForward(OpNode): | |||||
type = "Pooling" | |||||
opdef = builtin.Pooling | |||||
class AdaptivePoolingForward(OpNode): | |||||
type = "AdaptivePooling" | |||||
opdef = builtin.AdaptivePooling | |||||
class ROIPoolingForward(OpNode): | |||||
type = "ROIPooling" | |||||
opdef = builtin.ROIPooling | |||||
class DeformablePSROIPoolingForward(OpNode): | |||||
type = "DeformablePSROIPooling" | |||||
opdef = builtin.DeformablePSROIPooling | |||||
class ConvBiasForward(OpNode): | |||||
type = "ConvBias" | |||||
opdef = builtin.ConvBias | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(ConvBiasForward, cls).load(opr) | |||||
obj.params["dtype"] = opr.outputs[0].dtype | |||||
return obj | |||||
class BatchConvBiasForward(OpNode): | |||||
type = "BatchConvBias" | |||||
opdef = builtin.BatchConvBias | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(BatchConvBiasForward, cls).load(opr) | |||||
obj.params["dtype"] = opr.outputs[0].dtype | |||||
return obj | |||||
class BatchNormForward(OpNode): | |||||
type = "BatchNorm" | |||||
opdef = builtin.BatchNorm | |||||
class ROIAlignForward(OpNode): | |||||
type = "ROIAlign" | |||||
opdef = builtin.ROIAlign | |||||
class WarpPerspectiveForward(OpNode): | |||||
type = "WarpPerspective" | |||||
opdef = builtin.WarpPerspective | |||||
class WarpAffineForward(OpNode): | |||||
type = "WarpAffine" | |||||
opdef = builtin.WarpAffine | |||||
class RemapForward(OpNode): | |||||
type = "Remap" | |||||
opdef = builtin.Remap | |||||
class ResizeForward(OpNode): | |||||
type = "Resize" | |||||
opdef = builtin.Resize | |||||
class IndexingOneHot(OpNode): | |||||
type = "IndexingOneHot" | |||||
opdef = builtin.IndexingOneHot | |||||
class IndexingSetOneHot(OpNode): | |||||
type = "IndexingSetOneHot" | |||||
opdef = builtin.IndexingSetOneHot | |||||
class Copy(OpNode): | |||||
type = "Copy" | |||||
opdef = builtin.Copy | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(Copy, cls).load(opr) | |||||
obj.params["comp_node"] = opr.outputs[0].comp_node | |||||
return obj | |||||
class ArgsortForward(OpNode): | |||||
type = "Argsort" | |||||
opdef = builtin.Argsort | |||||
class Argmax(OpNode): | |||||
type = "Argmax" | |||||
opdef = builtin.Argmax | |||||
class Argmin(OpNode): | |||||
type = "Argmin" | |||||
opdef = builtin.Argmin | |||||
class CondTake(OpNode): | |||||
type = "CondTake" | |||||
opdef = builtin.CondTake | |||||
class TopK(OpNode): | |||||
type = "TopK" | |||||
opdef = builtin.TopK | |||||
class NvOf(OpNode): | |||||
type = "NvOf" | |||||
opdef = builtin.NvOf | |||||
class RNGOpr(OpNode): | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(RNGOpr, cls).load(opr) | |||||
if len(obj.params) == 3: | |||||
obj.opdef = builtin.GaussianRNG | |||||
obj.type = "GaussianRNG" | |||||
else: | |||||
obj.opdef = builtin.UniformRNG | |||||
obj.type = "UniformRNG" | |||||
return obj | |||||
class Linspace(OpNode): | |||||
type = "Linspace" | |||||
opdef = builtin.Linspace | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(Linspace, cls).load(opr) | |||||
obj.params["comp_node"] = opr.outputs[0].comp_node | |||||
return obj | |||||
class Eye(OpNode): | |||||
type = "Eye" | |||||
opdef = builtin.Eye | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(Eye, cls).load(opr) | |||||
obj.params["dtype"] = opr.outputs[0].dtype | |||||
obj.params["comp_node"] = opr.outputs[0].comp_node | |||||
return obj | |||||
class GetVarShape(OpNode): | |||||
type = "GetVarShape" | |||||
opdef = builtin.GetVarShape | |||||
class Concat(OpNode): | |||||
type = "Concat" | |||||
opdef = builtin.Concat | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(Concat, cls).load(opr) | |||||
obj.params["comp_node"] = Device("xpux").to_c() | |||||
return obj | |||||
class Broadcast(OpNode): | |||||
type = "Broadcast" | |||||
opdef = builtin.Broadcast | |||||
class Identity(OpNode): | |||||
type = "Identity" | |||||
opdef = builtin.Identity | |||||
class NMSKeep(OpNode): | |||||
type = "NMSKeep" | |||||
opdef = builtin.NMSKeep | |||||
# class ParamPackSplit | |||||
# class ParamPackConcat | |||||
class Dimshuffle(OpNode): | |||||
type = "Dimshuffle" | |||||
opdef = builtin.Dimshuffle | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(Dimshuffle, cls).load(opr) | |||||
del obj.params["ndim"] | |||||
return obj | |||||
class Reshape(OpNode): | |||||
type = "Reshape" | |||||
opdef = builtin.Reshape | |||||
class AxisAddRemove(OpNode): | |||||
type = "AxisAddRemove" | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = cls() | |||||
obj.name = opr.name | |||||
obj._opr = opr | |||||
params = json.loads(opr.params) | |||||
desc = params["desc"] | |||||
method = None | |||||
axis = [] | |||||
for i in desc: | |||||
if method is None: | |||||
method = i["method"] | |||||
assert method == i["method"] | |||||
axis.append(i["axisnum"]) | |||||
obj.params = {"axis": axis} | |||||
obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis | |||||
return obj | |||||
class IndexingBase(OpNode): | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = cls() | |||||
obj.name = opr.name | |||||
obj._opr = opr | |||||
params = json.loads(opr.params) | |||||
items = [ | |||||
[ | |||||
p["axis"], | |||||
bool(p["begin"]), | |||||
bool(p["end"]), | |||||
bool(p["step"]), | |||||
bool(p["idx"]), | |||||
] | |||||
for p in params | |||||
] | |||||
obj.params["items"] = items | |||||
return obj | |||||
class Subtensor(IndexingBase): | |||||
type = "Subtensor" | |||||
opdef = builtin.Subtensor | |||||
class SetSubtensor(IndexingBase): | |||||
type = "SetSubtensor" | |||||
opdef = builtin.SetSubtensor | |||||
class IncrSubtensor(IndexingBase): | |||||
type = "IncrSubtensor" | |||||
opdef = builtin.IncrSubtensor | |||||
class IndexingMultiAxisVec(IndexingBase): | |||||
type = "IndexingMultiAxisVec" | |||||
opdef = builtin.IndexingMultiAxisVec | |||||
class IndexingSetMultiAxisVec(IndexingBase): | |||||
type = "IndexingSetMultiAxisVec" | |||||
opdef = builtin.IndexingSetMultiAxisVec | |||||
class IndexingIncrMultiAxisVec(IndexingBase): | |||||
type = "IndexingIncrMultiAxisVec" | |||||
opdef = builtin.IndexingIncrMultiAxisVec | |||||
class MeshIndexing(IndexingBase): | |||||
type = "MeshIndexing" | |||||
opdef = builtin.MeshIndexing | |||||
class SetMeshIndexing(IndexingBase): | |||||
type = "SetMeshIndexing" | |||||
opdef = builtin.SetMeshIndexing | |||||
class IncrMeshIndexing(IndexingBase): | |||||
type = "IncrMeshIndexing" | |||||
opdef = builtin.IncrMeshIndexing | |||||
class BatchedMeshIndexing(IndexingBase): | |||||
type = "BatchedMeshIndexing" | |||||
opdef = builtin.BatchedMeshIndexing | |||||
class BatchedSetMeshIndexing(IndexingBase): | |||||
type = "BatchedSetMeshIndexing" | |||||
opdef = builtin.BatchedSetMeshIndexing | |||||
class BatchedIncrMeshIndexing(IndexingBase): | |||||
type = "BatchedIncrMeshIndexing" | |||||
opdef = builtin.BatchedIncrMeshIndexing | |||||
# class CollectiveComm | |||||
# class RemoteSend | |||||
# class RemoteRecv | |||||
# class TQT | |||||
# class FakeQuant | |||||
# class InplaceAdd | |||||
class AssertEqual(OpNode): | |||||
type = "AssertEqual" | |||||
opdef = builtin.AssertEqual | |||||
class ElemwiseMultiType(OpNode): | |||||
type = "ElemwiseMultiType" | |||||
opdef = builtin.ElemwiseMultiType | |||||
@classmethod | |||||
def load(cls, opr): | |||||
obj = super(ElemwiseMultiType, cls).load(opr) | |||||
obj.params["dtype"] = opr.outputs[0].dtype | |||||
return obj | |||||
class CvtColorForward(OpNode): | |||||
type = "CvtColor" | |||||
opdef = builtin.CvtColor |
@@ -160,6 +160,16 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
if (ctx.op->same_type<BackwardGraph>()) {
ctx.backward = true;
}
+if (py::isinstance<cg::VarNode>(py::handle(args[0]))){
+SmallVector<cg::VarNode*> vinputs(nargs);
+for (size_t i = 0; i < nargs; ++i) {
+vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>();
+}
+auto op = ctx.op.get();
+return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr();
+}
for (size_t i = 0; i < nargs; ++i) {
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
@@ -675,6 +685,16 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
tensors.emplace_back(descr);
continue;
}
+if (py::isinstance<cg::VarNode>(py::handle(handle))){
+auto var = py::handle(handle).cast<cg::VarNode *>();
+mgb::DType type = var->dtype();
+auto && descr = npy::dtype_mgb2np_descr(type);
+Py_INCREF(descr.get());
+tensors.emplace_back(descr.get());
+continue;
+}
PyArray_Descr* descr = scalar2dtype(handle);
if (descr) {
scalars.emplace_back(descr);
@@ -719,12 +739,14 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle);
-if (tw) {
+bool is_var = py::isinstance<cg::VarNode>(py::handle(handle));
+if (tw || is_var) {
if (!valid) {
-cn = tw->m_tensor->comp_node();
+cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
valid = true;
} else {
-CompNode cn1 = tw->m_tensor->comp_node();
+CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
if (cn1 != cn) {
throw py::value_error(ssprintf("ambiguous device: %s vs %s",
cn.to_string().c_str(), cn1.to_string().c_str()));
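For context, a hedged Python-side sketch of what this VarNode dispatch enables: functional ops applied to raw graph VarNodes are routed through OpDef::apply_on_var_node and return VarNodes, and dtype promotion and device inference handle them as well (mirrors the tests below; ``net`` and the var names are illustrative):

x = net.var_filter.name("a").as_unique().var    # cg::VarNode as seen from Python
y = net.var_filter.name("b").as_unique().var
z = F.add(x, y)     # dispatched on VarNode, returns a VarNode
z = F.relu(z)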
@@ -0,0 +1,351 @@ | |||||
import io | |||||
import numpy as np | |||||
import megengine.core.tensor.megbrain_graph as G | |||||
import megengine.functional as F | |||||
import megengine.module as M | |||||
import megengine.utils.network_node as N | |||||
from megengine.jit.tracing import trace | |||||
from megengine.tensor import Tensor | |||||
from megengine.utils.comp_graph_tools import GraphInference | |||||
from megengine.utils.network import Network as Net | |||||
from megengine.utils.network import as_oprnode | |||||
from megengine.utils.network_node import Host2DeviceCopy, VarNode | |||||
def test_replace_var(): | |||||
a = Tensor([1, 2]) | |||||
b = Tensor([3, 4]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(a, b): | |||||
return (a + b) * 2 | |||||
fwd(a, b) | |||||
orig_model = io.BytesIO() | |||||
fwd.dump( | |||||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||||
) | |||||
orig_model.seek(0) | |||||
graph = Net.load(orig_model) | |||||
vara = graph.var_filter.name("a").as_unique() | |||||
varb = graph.var_filter.name("b").as_unique() | |||||
out = F.mul(vara.var, varb.var) | |||||
out = F.relu(out) | |||||
var_list = graph.add_dep_oprs(out) | |||||
opnode = list(graph.opr_filter.has_input(vara)) | |||||
repl_dict = {opnode[0].outputs[0]: var_list[0]} | |||||
graph.replace_vars(repl_dict) | |||||
modified_model = io.BytesIO() | |||||
graph.dump(modified_model) | |||||
modified_model.seek(0) | |||||
load_graph = GraphInference(modified_model) | |||||
out = load_graph.run(a, b) | |||||
np.testing.assert_equal(out["o"], [6, 16]) | |||||
def test_replace_opr(): | |||||
a = Tensor([1, 2]) | |||||
b = Tensor([3, 4]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(a, b): | |||||
return (a + b) * 2 | |||||
fwd(a, b) | |||||
orig_model = io.BytesIO() | |||||
fwd.dump( | |||||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||||
) | |||||
orig_model.seek(0) | |||||
graph = Net.load(orig_model) | |||||
vara = graph.var_filter.name("a").as_unique() | |||||
varb = graph.var_filter.name("b").as_unique() | |||||
out1 = F.sub(vara.var, varb.var) | |||||
out1 = F.relu(out1) | |||||
var_list = graph.add_dep_oprs(out1) | |||||
repl_opr = as_oprnode(var_list) | |||||
orig_opr = graph.opr_filter.has_input(vara).as_unique() | |||||
repl_dict = {orig_opr: repl_opr} | |||||
graph.replace_oprs(repl_dict) | |||||
modified_model1 = io.BytesIO() | |||||
graph.dump(modified_model1) | |||||
modified_model1.seek(0) | |||||
load_graph = GraphInference(modified_model1) | |||||
out = load_graph.run(a, b) | |||||
np.testing.assert_equal(out["o"], [0, 0]) | |||||
def test_modify_params(): | |||||
a = Tensor([1, 2]) | |||||
b = Tensor([3, 4]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(a, b): | |||||
return (a + b) * 2 | |||||
fwd(a, b) | |||||
orig_model = io.BytesIO() | |||||
fwd.dump( | |||||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||||
) | |||||
orig_model.seek(0) | |||||
graph = Net.load(orig_model) | |||||
param_const = graph.params_filter.as_unique() | |||||
param_const.set_value(3) | |||||
modified_model = io.BytesIO() | |||||
graph.dump(modified_model) | |||||
modified_model.seek(0) | |||||
load_graph = GraphInference(modified_model) | |||||
out = load_graph.run(a, b) | |||||
np.testing.assert_equal(out["o"], [12, 18]) | |||||
def test_make_const(): | |||||
a = Tensor([1, 2]) | |||||
b = Tensor([3, 4]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(a, b): | |||||
return (a + b) * 2 | |||||
fwd(a, b) | |||||
orig_model = io.BytesIO() | |||||
fwd.dump( | |||||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||||
) | |||||
orig_model.seek(0) | |||||
graph = Net.load(orig_model) | |||||
const_b = graph.make_const(np.array([0.0, 0.0]), name="b") | |||||
varb = graph.var_filter.name("b").as_unique() | |||||
repl_dict = {varb: const_b} | |||||
graph.replace_vars(repl_dict) | |||||
modified_model = io.BytesIO() | |||||
graph.dump(modified_model) | |||||
modified_model.seek(0) | |||||
load_graph = GraphInference(modified_model) | |||||
out = load_graph.run(a) | |||||
np.testing.assert_equal(out["o"], [2, 4]) | |||||
def test_add_input(): | |||||
a = Tensor([1, 2]) | |||||
b = Tensor([3, 4]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(a, b): | |||||
return (a + b) * 2 | |||||
fwd(a, b) | |||||
orig_model = io.BytesIO() | |||||
fwd.dump( | |||||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||||
) | |||||
orig_model.seek(0) | |||||
graph = Net.load(orig_model) | |||||
inp_c = graph.make_input_node((2,), np.int32, name="c") | |||||
varo = graph.var_filter.name("o").as_unique() | |||||
out = F.add(varo.var, inp_c.var) | |||||
out = graph.add_dep_oprs(out)[0] | |||||
out.name = "o1" | |||||
graph.remove_output(varo) | |||||
graph.add_output(out) | |||||
modified_model = io.BytesIO() | |||||
graph.dump(modified_model) | |||||
modified_model.seek(0) | |||||
load_graph = GraphInference(modified_model) | |||||
out = load_graph.run(a, b, a) | |||||
np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy()) | |||||
def test_add_output(): | |||||
a = Tensor([1.0, 2.0]) | |||||
b = Tensor([3.0, 4.0]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(a, b): | |||||
return (a + b) * 2 | |||||
fwd(a, b) | |||||
orig_model = io.BytesIO() | |||||
fwd.dump( | |||||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||||
) | |||||
orig_model.seek(0) | |||||
net = Net.load(orig_model) | |||||
var_a = net.var_filter.name("a").as_unique() | |||||
var_b = net.var_filter.name("b").as_unique() | |||||
y = F.add(var_a.var, var_b.var) | |||||
y = F.sigmoid(y) | |||||
new_vars = net.add_dep_oprs(y)[0] | |||||
new_vars.name = "o1" | |||||
net.add_output(new_vars) | |||||
modified_model = io.BytesIO() | |||||
net.dump(modified_model) | |||||
modified_model.seek(0) | |||||
g = GraphInference(modified_model) | |||||
out = g.run(a.numpy(), b.numpy()) | |||||
np.testing.assert_equal(out["o"], ((a + b) * 2).numpy()) | |||||
np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy()) | |||||
def test_query(): | |||||
class Model(M.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.conv1 = M.Conv2d(3, 32, 3) | |||||
self.conv2 = M.Conv2d(32, 32, 3) | |||||
self.conv3 = M.Conv2d(32, 32, 3) | |||||
def forward(self, data): | |||||
x = self.conv1(data) | |||||
x = self.conv2(x) | |||||
x = self.conv3(x) | |||||
return x | |||||
n = Model() | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return n(data) | |||||
fwd(Tensor(np.random.random((1, 3, 224, 224)))) | |||||
orig_model = io.BytesIO() | |||||
fwd.dump( | |||||
orig_model, | |||||
arg_names=["data"], | |||||
output_names="o", | |||||
keep_opr_name=True, | |||||
keep_var_name=True, | |||||
optimize_for_inference=False, | |||||
) | |||||
orig_model.seek(0) | |||||
graph = Net.load(orig_model) | |||||
r = graph.data_providers_filter.as_count() | |||||
assert r == 1 | |||||
opr = graph.get_opr_by_type(Host2DeviceCopy) | |||||
assert isinstance(opr, Host2DeviceCopy) | |||||
r1 = graph.params_filter.as_count() | |||||
assert r1 == 6 | |||||
r2 = graph.opr_filter.type(N.ConvolutionForward).as_count() | |||||
assert r2 == 3 | |||||
r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count() | |||||
assert r3 == len(graph.all_oprs) - r2 | |||||
var = graph.var_filter.name("data").as_unique() | |||||
r4 = graph.opr_filter.has_input(var).as_count() | |||||
assert r4 == 1 | |||||
r5 = graph.opr_filter.name("data").as_count() | |||||
assert r5 == 1 | |||||
opr = graph.get_opr_by_name("data") | |||||
assert isinstance(opr, Host2DeviceCopy) | |||||
var = graph.get_var_by_name("data") | |||||
assert isinstance(var, VarNode) | |||||
r6 = graph.var_filter.name("*bias").as_count() | |||||
assert r6 == 3 | |||||
def test_optimize_for_inference(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def f(x): | |||||
return F.exp(x) | |||||
orig_model = io.BytesIO() | |||||
f(Tensor(5.0)) | |||||
f.dump(orig_model, optimize_for_inference=False) | |||||
orig_model.seek(0) | |||||
optimize_model = io.BytesIO() | |||||
net = Net.load(orig_model) | |||||
net.dump(optimize_model, enable_io16xc32=True) | |||||
optimize_model.seek(0) | |||||
res = G.load_graph(optimize_model) | |||||
computing_input = res.output_vars_list[0].owner.inputs[0] | |||||
assert computing_input.dtype == np.float16 | |||||
def test_reset_batchsize(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def f(x): | |||||
return F.exp(x) | |||||
orig_model = io.BytesIO() | |||||
f(Tensor(np.random.random((3, 3, 224, 224)))) | |||||
f.dump(orig_model, optimize_for_inference=False) | |||||
orig_model.seek(0) | |||||
modified_model = io.BytesIO() | |||||
net = Net.load(orig_model) | |||||
net.reset_batch_size(1) | |||||
net.dump(modified_model, optimize_for_inference=False) | |||||
modified_model.seek(0) | |||||
net1 = Net.load(modified_model) | |||||
assert net1.data_providers_filter.as_unique().shape[0] == 1 | |||||
def test_modify_opr_name(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def f(x): | |||||
return F.exp(x) | |||||
orig_model = io.BytesIO() | |||||
f(Tensor(np.random.random((3, 3, 224, 224)))) | |||||
f.dump(orig_model, arg_names=["a"], optimize_for_inference=False) | |||||
orig_model.seek(0) | |||||
modified_model = io.BytesIO() | |||||
net = Net.load(orig_model) | |||||
net.modify_opr_names("net") | |||||
net.modify_opr_names(lambda x: "net1." + x) | |||||
net.dump(modified_model, optimize_for_inference=False) | |||||
modified_model.seek(0) | |||||
net1 = Net.load(modified_model) | |||||
assert net1.data_providers_filter.as_unique().name == "net1.net.a" |
@@ -0,0 +1,712 @@ | |||||
import io | |||||
import os | |||||
import platform | |||||
import numpy as np | |||||
import pytest | |||||
import megengine.core.tensor.dtype as dtype | |||||
import megengine.core.tensor.megbrain_graph as G | |||||
import megengine.functional as F | |||||
import megengine.module as M | |||||
import megengine.random as rand | |||||
from megengine.core._imperative_rt.core2 import apply | |||||
from megengine.core._wrap import Device | |||||
from megengine.core.ops import builtin | |||||
from megengine.device import is_cuda_available | |||||
from megengine.functional.external import tensorrt_runtime_opr | |||||
from megengine.jit.tracing import trace | |||||
from megengine.tensor import Tensor | |||||
from megengine.utils.comp_graph_tools import GraphInference | |||||
from megengine.utils.network import Network as Net | |||||
def check_pygraph_dump(trace_func, inp_data, expect_results): | |||||
orig_model = io.BytesIO() | |||||
inp_size = len(inp_data) | |||||
out_size = len(expect_results) | |||||
arg_names = ["arg_{}".format(i) for i in range(inp_size)] | |||||
output_names = ["out_{}".format(i) for i in range(out_size)] | |||||
trace_func.dump( | |||||
orig_model, | |||||
arg_names=arg_names, | |||||
output_names=output_names, | |||||
optimize_for_inference=False, | |||||
) | |||||
orig_model.seek(0) | |||||
net = Net.load(orig_model) | |||||
file = io.BytesIO() | |||||
net.dump(file, optimize_for_inference=False) | |||||
file.seek(0) | |||||
graph = GraphInference(file) | |||||
inp_dict = {arg_names[i]: inp_data[i].numpy() for i in range(inp_size)} | |||||
results = graph.run(inp_dict=inp_dict) | |||||
for ind, tensor in enumerate(expect_results): | |||||
np.testing.assert_equal(tensor.numpy(), results[output_names[ind]]) | |||||
assert tensor.dtype == results[output_names[ind]].dtype | |||||
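check_pygraph_dump round-trips a traced function: it dumps the trace, reloads it through Network, re-dumps the loaded graph, and then runs the re-dumped model with GraphInference, comparing values and dtypes against the eagerly computed results. A minimal standalone sketch of the same flow with a trivial traced function (the names here are illustrative, not taken from the tests):

```python
import io

import numpy as np

from megengine.jit.tracing import trace
from megengine.tensor import Tensor
from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net


@trace(symbolic=True, capture_as_const=True)
def double(x):
    return x * 2


double(Tensor(np.ones((2,), dtype=np.float32)))  # run once to record the trace
buf = io.BytesIO()
double.dump(buf, arg_names=["x"], output_names=["y"], optimize_for_inference=False)
buf.seek(0)

redump = io.BytesIO()
Net.load(buf).dump(redump, optimize_for_inference=False)  # load and re-dump via Network
redump.seek(0)

out = GraphInference(redump).run(inp_dict={"x": np.ones((2,), dtype=np.float32)})
np.testing.assert_equal(out["y"], np.full((2,), 2.0, dtype=np.float32))
```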
def test_elemwise(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x, y): | |||||
z1 = x * y | |||||
z2 = x + y | |||||
z3 = z1 / z2 | |||||
z3 = z3 ** 3 | |||||
return z3 | |||||
x = Tensor([1.0, 2.0]) | |||||
y = Tensor([3.0, 5.0]) | |||||
result = fwd(x, y) | |||||
check_pygraph_dump(fwd, [x, y], [result]) | |||||
def test_reduce(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
x = data.sum(axis=2) | |||||
x = x.mean(axis=1) | |||||
return x | |||||
data = Tensor(np.random.random((1, 32, 32))) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
def test_typecvt(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return data.astype(dtype.qint8(0.8)) | |||||
x = Tensor(np.random.random((2, 3)) * 255) | |||||
result = fwd(x) | |||||
check_pygraph_dump(fwd, [x], [result]) | |||||
def test_matinv(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return F.matinv(data) | |||||
data = Tensor(np.random.random((5, 5))) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
def test_matmul(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data1, data2): | |||||
return F.matmul(data1, data2) | |||||
data1 = Tensor(np.random.random((32, 64))) | |||||
data2 = Tensor(np.random.random((64, 16))) | |||||
result = fwd(data1, data2) | |||||
check_pygraph_dump(fwd, [data1, data2], [result]) | |||||
def test_batchmatmul(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x, y): | |||||
return F.matmul(x, y) | |||||
x = Tensor(np.random.random((3, 3, 5))) | |||||
y = Tensor(np.random.random((3, 5, 3))) | |||||
result = fwd(x, y) | |||||
check_pygraph_dump(fwd, [x, y], [result]) | |||||
def test_dot(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x, y): | |||||
return F.dot(x, y) | |||||
x = Tensor([1.0, 2.0, 3.0]) | |||||
y = Tensor([3.0, 4.0, 5.0]) | |||||
result = fwd(x, y) | |||||
check_pygraph_dump(fwd, [x, y], [result]) | |||||
def test_svd(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
_, out, _ = F.svd(data) | |||||
return out | |||||
input = Tensor(np.random.random((1, 1, 3, 3))) | |||||
result = fwd(input) | |||||
check_pygraph_dump(fwd, [input], [result]) | |||||
def test_conv(): | |||||
conv = M.Conv2d(3, 32, 3) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return conv(data) | |||||
data = Tensor(np.random.random((1, 3, 32, 32))) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
def test_deformable_conv(): | |||||
if not is_cuda_available(): | |||||
return | |||||
conv = M.DeformableConv2d(3, 32, 3) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data, offset, mask): | |||||
return conv(data, offset, mask) | |||||
data = Tensor(np.random.random((1, 3, 32, 32))) | |||||
offset = Tensor(np.ones((32, 3 * 3 * 2, 30, 30)).astype("int32") * 5) | |||||
mask = Tensor(np.ones((32, 3 * 3, 30, 30)).astype("int32")) | |||||
out = fwd(data, offset, mask) | |||||
check_pygraph_dump(fwd, [data, offset, mask], [out]) | |||||
def test_convtranspose(): | |||||
deconv = M.ConvTranspose2d(32, 32, 3) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return deconv(data) | |||||
data = Tensor(np.random.random((1, 32, 32, 32))) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
@pytest.mark.skip(reason="pytest aborted") | |||||
def test_grouplocal(): | |||||
n = M.LocalConv2d(3, 32, 32, 32, 3) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return n(data) | |||||
input = Tensor(np.random.random((1, 3, 32, 32))) | |||||
result = fwd(input) | |||||
check_pygraph_dump(fwd, [input], [result]) | |||||
def test_pooling(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
out = F.max_pool2d(data, 2, 2) | |||||
out = F.avg_pool2d(out, 2, 2) | |||||
return out | |||||
data = Tensor(np.random.random((1, 3, 64, 64))) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
def test_adaptivepooling(): | |||||
pool1 = M.AdaptiveMaxPool2d((2, 2)) | |||||
pool2 = M.AdaptiveAvgPool2d((2, 2)) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
out = pool1(data) | |||||
out = pool2(out) | |||||
return out | |||||
input = Tensor(np.random.random((1, 3, 32, 32))) | |||||
result = fwd(input) | |||||
check_pygraph_dump(fwd, [input], [result]) | |||||
def test_roipooling(): | |||||
inp = Tensor(np.random.random((1, 1, 128, 128))) | |||||
rois = Tensor(np.random.random((4, 5))) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp, rois): | |||||
return F.nn.roi_pooling(inp, rois, (2, 2), scale=2.0) | |||||
output = fwd(inp, rois) | |||||
check_pygraph_dump(fwd, [inp, rois], [output]) | |||||
def test_deformable_ps_roi_pooling(): | |||||
inp = Tensor(np.random.random((1, 256, 64, 64)).astype("float32")) | |||||
rois = Tensor(np.random.random((1, 5)).astype("float32")) | |||||
trans = Tensor(np.random.random((24, 2, 7, 7)).astype("float32")) | |||||
pooled_h = 7 | |||||
pooled_w = 7 | |||||
sample_per_part = 4 | |||||
no_trans = False | |||||
part_size = 7 | |||||
spatial_scale = 1.0 / 64 | |||||
trans_std = 0.1 | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp, rois, trans): | |||||
y = F.deformable_psroi_pooling( | |||||
inp, | |||||
rois, | |||||
trans, | |||||
no_trans, | |||||
part_size, | |||||
pooled_h, | |||||
pooled_w, | |||||
sample_per_part, | |||||
spatial_scale, | |||||
trans_std, | |||||
) | |||||
return y | |||||
result = fwd(inp, rois, trans) | |||||
check_pygraph_dump(fwd, [inp, rois, trans], [result]) | |||||
def test_convbias(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp, weight, bias): | |||||
return F.quantized.conv_bias_activation( | |||||
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" | |||||
) | |||||
inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | |||||
weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0)) | |||||
bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0)) | |||||
result = fwd(inp, weight, bias) | |||||
check_pygraph_dump(fwd, [inp, weight, bias], [result]) | |||||
def test_batch_convbias(): | |||||
if is_cuda_available(): | |||||
return | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp, weight, bias): | |||||
return F.quantized.batch_conv_bias_activation( | |||||
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" | |||||
) | |||||
inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | |||||
weight = Tensor(np.random.random((1, 32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0)) | |||||
bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0)) | |||||
result = fwd(inp, weight, bias) | |||||
check_pygraph_dump(fwd, [inp, weight, bias], [result]) | |||||
def test_batchnorm(): | |||||
bn = M.BatchNorm2d(32) | |||||
bn.eval() | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return bn(data) | |||||
data = Tensor(np.random.random((1, 32, 32, 32))) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
def test_roialign(): | |||||
inp = Tensor(np.random.randn(1, 1, 128, 128)) | |||||
rois = Tensor(np.random.random((4, 5))) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp, rois): | |||||
return F.nn.roi_align(inp, rois, (2, 2)) | |||||
output = fwd(inp, rois) | |||||
check_pygraph_dump(fwd, [inp, rois], [output]) | |||||
def test_warpperspective(): | |||||
inp_shape = (1, 1, 4, 4) | |||||
x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||||
M_shape = (1, 3, 3) | |||||
# M defines a translation: dst(1, 1, h, w) = src(1, 1, h+1, w+1) | |||||
M = Tensor( | |||||
np.array( | |||||
[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32 | |||||
).reshape(M_shape) | |||||
) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x, M): | |||||
return F.warp_perspective(x, M, (2, 2)) | |||||
result = fwd(x, M) | |||||
check_pygraph_dump(fwd, [x, M], [result]) | |||||
def test_warpaffine(): | |||||
inp_shape = (1, 3, 3, 3) | |||||
x = Tensor(np.arange(27, dtype=np.float32).reshape(inp_shape)) | |||||
weightv = Tensor([[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x, weightv): | |||||
return F.warp_affine(x, weightv, (2, 2), border_mode="WRAP") | |||||
outp = fwd(x, weightv) | |||||
check_pygraph_dump(fwd, [x, weightv], [outp]) | |||||
def test_remap(): | |||||
inp_shape = (1, 1, 4, 4) | |||||
inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||||
map_xy_shape = (1, 2, 2, 2) | |||||
map_xy = Tensor( | |||||
np.array( | |||||
[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32 | |||||
).reshape(map_xy_shape) | |||||
) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp, map_xy): | |||||
return F.remap(inp, map_xy) | |||||
out = fwd(inp, map_xy) | |||||
check_pygraph_dump(fwd, [inp, map_xy], [out]) | |||||
def test_resize(): | |||||
x = Tensor(np.random.randn(10, 3, 32, 32)) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x): | |||||
return F.nn.interpolate(x, size=(16, 16), mode="BILINEAR") | |||||
out = fwd(x) | |||||
check_pygraph_dump(fwd, [x], [out]) | |||||
def test_index_onehot(): | |||||
src = Tensor([[1.0, 2.0]]) | |||||
index = Tensor([0]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(src, index): | |||||
return F.indexing_one_hot(src, index) | |||||
out = fwd(src, index) | |||||
check_pygraph_dump(fwd, [src, index], [out]) | |||||
def test_set_onehot(): | |||||
x = Tensor(np.arange(1, 4, dtype=np.int32)) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x): | |||||
return F.one_hot(x, num_classes=4) | |||||
out = fwd(x) | |||||
check_pygraph_dump(fwd, [x], [out]) | |||||
def test_copy(): | |||||
x = Tensor([1, 2, 3]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x): | |||||
return x.to("cpu0:0") | |||||
o = fwd(x) | |||||
check_pygraph_dump(fwd, [x], [o]) | |||||
def test_argsort(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return F.argsort(data, True) | |||||
data = Tensor([1.0, 2.0, 3.0, 5.0]) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
def test_argmax_min(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return F.argmax(data), F.argmin(data) | |||||
data = Tensor(np.random.random((10, 10))) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], result) | |||||
def test_condtake(): | |||||
mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) | |||||
x = Tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(mask, x): | |||||
v, index = F.cond_take(mask, x) | |||||
return v, index | |||||
v, index = fwd(mask, x) | |||||
check_pygraph_dump(fwd, [mask, x], [v, index]) | |||||
def test_topk(): | |||||
x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x): | |||||
top, indices = F.topk(x, 5) | |||||
return top, indices | |||||
top, indices = fwd(x) | |||||
check_pygraph_dump(fwd, [x], [top, indices]) | |||||
def test_random(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(): | |||||
x = rand.uniform(size=(2, 2)) | |||||
y = rand.normal(size=(1, 3, 3, 3)) | |||||
return x, y | |||||
x, y = fwd() | |||||
check_pygraph_dump(fwd, [], [x, y]) | |||||
def test_tensor_gen(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(): | |||||
a = F.linspace(3, 10, 3, device=Device("xpux").to_c()) | |||||
b = F.eye(3, device=Device("xpux").to_c()) | |||||
return a, b | |||||
a, b = fwd() | |||||
check_pygraph_dump(fwd, [], [a, b]) | |||||
def test_getvarshape(): | |||||
op = builtin.GetVarShape(axis=1) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return apply(op, data)[0] | |||||
data = Tensor(np.random.random((1, 2, 3, 4))) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
def test_concat(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data1, data2): | |||||
return F.concat([data1, data2], axis=1) | |||||
x = Tensor(np.random.random((2, 3))) | |||||
y = Tensor(np.random.random((2, 5))) | |||||
result = fwd(x, y) | |||||
check_pygraph_dump(fwd, [x, y], [result]) | |||||
def test_broadcast(): | |||||
inp = Tensor([[1], [2], [3], [4]]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp): | |||||
return F.broadcast_to(inp, (4, 4)) | |||||
out = fwd(inp) | |||||
check_pygraph_dump(fwd, [inp], [out]) | |||||
def test_identity(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return F.copy(data) | |||||
data = Tensor([1.0, 2.0]) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
@pytest.mark.skip(reason="advance indexing trace error") | |||||
def test_nms(): | |||||
x = np.zeros((100, 4)) | |||||
np.random.seed(42) | |||||
x[:, :2] = np.random.rand(100, 2) * 20 | |||||
x[:, 2:] = np.random.rand(100, 2) * 20 + 100 | |||||
scores = Tensor(np.random.rand(100)) | |||||
inp = Tensor(x) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp, scores): | |||||
return F.nn.nms(inp, scores, iou_thresh=0.7, max_output=3) | |||||
result = fwd(inp, scores) | |||||
check_pygraph_dump(fwd, [inp, scores], [result]) | |||||
def test_dimshuffle(): | |||||
inp = Tensor([1, 2, 3, 4]) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp): | |||||
return inp.T | |||||
out = fwd(inp) | |||||
check_pygraph_dump(fwd, [inp], [out]) | |||||
def test_reshape(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
return data.reshape((1, 8)) | |||||
data = Tensor(np.random.random((1, 2, 2, 2))) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
def test_add_remove_axis(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(data): | |||||
x = F.expand_dims(data, [0, 0]) | |||||
y = F.squeeze(x, 0) | |||||
return y | |||||
data = Tensor([1.0, 2.0]) | |||||
result = fwd(data) | |||||
check_pygraph_dump(fwd, [data], [result]) | |||||
@pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||||
def test_subtensor(mode): | |||||
items = [[0, True, True, True, False], [1, False, False, False, True]] | |||||
data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random(2))] | |||||
if mode == "get": | |||||
op = builtin.Subtensor(items) | |||||
data = data[:1] | |||||
if mode == "set": | |||||
op = builtin.SetSubtensor(items) | |||||
if mode == "inc": | |||||
op = builtin.IncrSubtensor(items) | |||||
tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)] | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(*tensors): | |||||
return apply(op, *tensors)[0] | |||||
result = fwd(*data, *tensors) | |||||
check_pygraph_dump(fwd, data + tensors, [result]) | |||||
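In these subtensor tests, each entry of items appears to describe one axis as [axis, has_begin, has_end, has_step, has_idx], with the scalar tensors supplying the flagged values in order; under that reading, the "get" case above slices data[0] as [0:4:2, 3], which also explains why the value tensor for "set"/"inc" has shape (2,). A hedged numpy analogue (an interpretation, not behaviour asserted by the test):

```python
# Hedged numpy analogue of the "get" case: items and tensors are read as
# axis 0 -> begin=0, end=4, step=2 and axis 1 -> idx=3.
import numpy as np

x = np.random.random((5, 5))
ref = x[0:4:2, 3]  # shape (2,), matching the value tensor used by "set"/"inc"
```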
@pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||||
def test_advance_indexing(mode): | |||||
items = [[0, False, False, False, True]] | |||||
tensors = [Tensor([0, 4, 2])] | |||||
data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 5)))] | |||||
if mode == "get": | |||||
op = builtin.IndexingMultiAxisVec(items) | |||||
data = data[:1] | |||||
if mode == "set": | |||||
op = builtin.IndexingSetMultiAxisVec(items) | |||||
if mode == "inc": | |||||
op = builtin.IndexingIncrMultiAxisVec(items) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(*tensors): | |||||
return apply(op, *tensors)[0] | |||||
result = fwd(*data, *tensors) | |||||
check_pygraph_dump(fwd, data + tensors, [result]) | |||||
@pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||||
def test_mesh_indexing(mode): | |||||
items = [[0, True, True, True, False], [1, False, False, False, True]] | |||||
tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])] | |||||
data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))] | |||||
if mode == "get": | |||||
op = builtin.IndexingMultiAxisVec(items) | |||||
data = data[:1] | |||||
if mode == "set": | |||||
op = builtin.IndexingSetMultiAxisVec(items) | |||||
if mode == "inc": | |||||
op = builtin.IndexingIncrMultiAxisVec(items) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(*tensors): | |||||
return apply(op, *tensors)[0] | |||||
result = fwd(*data, *tensors) | |||||
check_pygraph_dump(fwd, data + tensors, [result]) | |||||
@pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||||
def test_batch_mesh_indexing(mode): | |||||
items = [[1, False, False, False, True], [2, False, False, False, True]] | |||||
tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])] | |||||
data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))] | |||||
if mode == "get": | |||||
op = builtin.BatchedMeshIndexing(items) | |||||
data = data[:1] | |||||
if mode == "set": | |||||
op = builtin.BatchedSetMeshIndexing(items) | |||||
if mode == "inc": | |||||
op = builtin.BatchedIncrMeshIndexing(items) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(*tensors): | |||||
return apply(op, *tensors)[0] | |||||
result = fwd(*data, *tensors) | |||||
check_pygraph_dump(fwd, data + tensors, [result]) | |||||
@pytest.mark.skip(reason="tmp skip") | |||||
def test_assert_equal(): | |||||
g = G.Graph() | |||||
inp1 = g.make_h2d(dtype=np.float32, device="xpux") | |||||
inp2 = g.make_h2d(dtype=np.float32, device="xpux") | |||||
op = builtin.AssertEqual(maxerr=1e-5) | |||||
out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0] | |||||
print(out) | |||||
g.compile(out) | |||||
file = io.BytesIO() | |||||
out_model = G.dump_graph([out]) | |||||
file.write(out_model[0]) | |||||
file.seek(0) | |||||
net = Net.load(file) | |||||
dump_file = io.BytesIO() | |||||
net.dump(dump_file) | |||||
dump_file.seek(0) | |||||
g = GraphInference(dump_file) | |||||
g.run(np.array([1.0, 2.0]), np.array([1.0, 2.0])) | |||||
def test_elemwise_multitype(): | |||||
op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0)) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(x, y): | |||||
return apply(op, x, y)[0] | |||||
x = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0)) | |||||
y = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0)) | |||||
result = fwd(x, y) | |||||
check_pygraph_dump(fwd, [x, y], [result]) | |||||
def test_cvtcolor(): | |||||
inp = np.random.randn(3, 3, 3, 3).astype(np.float32) | |||||
x = Tensor(inp) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def fwd(inp): | |||||
return F.img_proc.cvt_color(inp, mode="RGB2GRAY") | |||||
result = fwd(x) | |||||
check_pygraph_dump(fwd, [x], [result]) |
@@ -17,9 +17,20 @@ | |||||
#include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
#include "megbrain/opr/dnn/lrn.h" | #include "megbrain/opr/dnn/lrn.h" | ||||
#include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
#include "megbrain/opr/dnn/roi_pooling.h" | |||||
#include "megbrain/opr/dnn/roi_align.h" | |||||
#include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
#include "megbrain/opr/standalone/nms_opr.h" | |||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
#include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
#include "megbrain/opr/rand.h" | |||||
#include "megbrain/opr/dnn/batch_norm.h" | |||||
#include "megbrain/opr/misc.h" | |||||
#include "megbrain/opr/indexing.h" | |||||
#include "megbrain/opr/internal/indexing_helper.h" | |||||
#include "megbrain/opr/nn_int.h" | |||||
#include "megbrain/opr/tensor_gen.h" | |||||
#if MGB_ENABLE_JSON | #if MGB_ENABLE_JSON | ||||
#include "megdnn/opr_param_json.h" | #include "megdnn/opr_param_json.h" | ||||
#endif | #endif | ||||
@@ -354,7 +365,7 @@ uint64_t opr_footprint_func<opr::DeformableConvForward>( | |||||
auto&& out_shape = opr->output()[0]->shape(); | auto&& out_shape = opr->output()[0]->shape(); | ||||
auto&& filter_shape = opr->input()[1]->shape(); | auto&& filter_shape = opr->input()[1]->shape(); | ||||
using Param = opr::DeformableConvForward::Param; | using Param = opr::DeformableConvForward::Param; | ||||
auto&& param = opr->cast_final_safe<opr::Convolution>().param(); | |||||
auto&& param = opr->cast_final_safe<opr::DeformableConvForward>().param(); | |||||
size_t fh, fw, icpg; | size_t fh, fw, icpg; | ||||
mgb_assert(param.format == Param::Format::NCHW); | mgb_assert(param.format == Param::Format::NCHW); | ||||
if (param.sparse == Param::Sparse::GROUP) { | if (param.sparse == Param::Sparse::GROUP) { | ||||
@@ -425,9 +436,11 @@ uint64_t opr_footprint_func<opr::BatchConvBiasForward>( | |||||
auto&& filter_shape = opr->input()[1]->shape(); | auto&& filter_shape = opr->input()[1]->shape(); | ||||
using Param = opr::BatchConvBiasForward::Param; | using Param = opr::BatchConvBiasForward::Param; | ||||
auto&& param = opr->cast_final_safe<opr::BatchConvBiasForward>().param(); | auto&& param = opr->cast_final_safe<opr::BatchConvBiasForward>().param(); | ||||
mgb_assert(param.format == Param::Format::NCHW4); | |||||
size_t packed_channels = 4; | |||||
size_t packed_channels = 1; | |||||
size_t kern_spatial_pos = 3; | size_t kern_spatial_pos = 3; | ||||
if (param.format == Param::Format::NCHW4) { | |||||
packed_channels = 4; | |||||
} | |||||
size_t fh = filter_shape[kern_spatial_pos], | size_t fh = filter_shape[kern_spatial_pos], | ||||
fw = filter_shape[kern_spatial_pos + 1]; | fw = filter_shape[kern_spatial_pos + 1]; | ||||
return out_shape.total_nr_elems() * fh * fw * src_shape[1] * | return out_shape.total_nr_elems() * fh * fw * src_shape[1] * | ||||
@@ -508,7 +521,29 @@ REGISTE_PARAM_JSON_FUNC(LocalShareBackwardFilter) | |||||
REGISTE_PARAM_JSON_FUNC(DeformableConvForward) | REGISTE_PARAM_JSON_FUNC(DeformableConvForward) | ||||
REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter) | REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter) | ||||
REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData) | REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData) | ||||
REGISTE_PARAM_JSON_FUNC(DeformablePSROIPoolingForward) | |||||
REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward) | REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward) | ||||
REGISTE_PARAM_JSON_FUNC(BatchNormForward) | |||||
REGISTE_PARAM_JSON_FUNC(ElemwiseMultiType) | |||||
REGISTE_PARAM_JSON_FUNC(Argsort) | |||||
REGISTE_PARAM_JSON_FUNC(Argmax) | |||||
REGISTE_PARAM_JSON_FUNC(Argmin) | |||||
REGISTE_PARAM_JSON_FUNC(AdaptivePooling) | |||||
REGISTE_PARAM_JSON_FUNC(ROIPooling) | |||||
REGISTE_PARAM_JSON_FUNC(ROIAlign) | |||||
REGISTE_PARAM_JSON_FUNC(WarpPerspective) | |||||
REGISTE_PARAM_JSON_FUNC(WarpAffine) | |||||
REGISTE_PARAM_JSON_FUNC(Remap) | |||||
REGISTE_PARAM_JSON_FUNC(Resize) | |||||
REGISTE_PARAM_JSON_FUNC(IndexingOneHot) | |||||
REGISTE_PARAM_JSON_FUNC(IndexingSetOneHot) | |||||
REGISTE_PARAM_JSON_FUNC(TopK) | |||||
REGISTE_PARAM_JSON_FUNC(UniformRNG) | |||||
REGISTE_PARAM_JSON_FUNC(GaussianRNG) | |||||
REGISTE_PARAM_JSON_FUNC(Linspace) | |||||
REGISTE_PARAM_JSON_FUNC(Eye) | |||||
REGISTE_PARAM_JSON_FUNC(CvtColor) | |||||
template <> | template <> | ||||
std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>( | std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>( | ||||
@@ -547,24 +582,83 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>( | |||||
}); | }); | ||||
} | } | ||||
std::shared_ptr<json::Value> indexing_param_to_json( | |||||
const std::vector<opr::indexing::AxisIndexer>& indices) { | |||||
auto desc = json::Array::make(); | |||||
for (auto& index : indices) { | |||||
desc->add(json::Object::make({ | |||||
{"axis", json::NumberInt::make(index.axis.get_raw())}, | |||||
{"begin", | |||||
json::NumberInt::make(index.begin.node() != nullptr)}, | |||||
{"end", json::NumberInt::make(index.end.node() != nullptr)}, | |||||
{"step", | |||||
json::NumberInt::make(index.step.node() != nullptr)}, | |||||
{"idx", json::NumberInt::make(index.idx.node() != nullptr)}, | |||||
})); | |||||
} | |||||
return desc; | |||||
} | |||||
#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \ | |||||
template <> \ | |||||
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \ | |||||
cg::OperatorNodeBase * opr) { \ | |||||
auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \ | |||||
return indexing_param_to_json(indices); \ | |||||
} | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(SetSubtensor); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(IncrSubtensor); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingMultiAxisVec); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingSetMultiAxisVec); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingIncrMultiAxisVec); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(MeshIndexing); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(IncrMeshIndexing); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(SetMeshIndexing); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing); | |||||
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing); | |||||
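indexing_param_to_json and the REGISTE_INDEXING_PARAM_JSON_FUNC instantiations above give every indexing-family operator a JSON dump of its index descriptor: one object per axis, in which begin/end/step/idx are 0/1 flags recording whether that component is present. For the Subtensor descriptor used in test_subtensor, the emitted array would look roughly like the following (a hedged illustration of the shape, not output captured from the code):

```python
# Illustrative only; the field values are presence flags, not the actual
# begin/end/step/idx scalars.
expected = [
    {"axis": 0, "begin": 1, "end": 1, "step": 1, "idx": 0},
    {"axis": 1, "begin": 0, "end": 0, "step": 0, "idx": 1},
]
```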
template <> | template <> | ||||
std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>( | |||||
std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>( | |||||
cg::OperatorNodeBase * opr) { | cg::OperatorNodeBase * opr) { | ||||
auto desc = json::Array::make(); | auto desc = json::Array::make(); | ||||
auto indices = opr->cast_final_safe<opr::Subtensor>().index_desc(); | |||||
for (auto &index : indices){ | |||||
desc->add( | |||||
json::Object::make({ | |||||
{"axis", json::NumberInt::make(index.axis.get_raw())}, | |||||
{"begin", json::NumberInt::make(index.begin.node() != nullptr)}, | |||||
{"end", json::NumberInt::make(index.end.node() != nullptr)}, | |||||
{"step", json::NumberInt::make(index.step.node() != nullptr)}, | |||||
{"idx", json::NumberInt::make(index.idx.node() != nullptr)}, | |||||
})); | |||||
auto axis_param = opr->cast_final_safe<opr::Reshape>().param(); | |||||
if (axis_param.axis != axis_param.MAX_NDIM){ | |||||
return json::Object::make({ | |||||
{"axis", json::NumberInt::make(axis_param.axis)}, | |||||
}); | |||||
} else { | |||||
return json::Object::make(); | |||||
} | } | ||||
} | |||||
return desc; | |||||
template <> | |||||
std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>( | |||||
cg::OperatorNodeBase * opr) { | |||||
auto desc = json::Array::make(); | |||||
auto axis_param = opr->cast_final_safe<opr::GetVarShape>().param(); | |||||
if (axis_param.axis != axis_param.MAX_NDIM){ | |||||
return json::Object::make({ | |||||
{"axis", json::NumberInt::make(axis_param.axis)}, | |||||
}); | |||||
} else { | |||||
return json::Object::make(); | |||||
} | |||||
} | } | ||||
template <> | |||||
std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>( | |||||
cg::OperatorNodeBase * opr) { | |||||
auto nms_param = opr->cast_final_safe<opr::standalone::NMSKeep>().param(); | |||||
return json::Object::make({ | |||||
{"iou_thresh", json::Number::make(nms_param.iou_thresh)}, | |||||
{"max_output", json::Number::make(nms_param.max_output)}, | |||||
}); | |||||
} | |||||
#endif // MGB_ENABLE_JSON | #endif // MGB_ENABLE_JSON | ||||
} // namespace | } // namespace | ||||
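The NMSKeep dumper above emits a flat object with the two runtime-visible parameters; for the operator built in test_nms (iou_thresh=0.7, max_output=3) the JSON would look roughly like this (again an illustration of the shape, not captured output):

```python
expected = {"iou_thresh": 0.7, "max_output": 3}
```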
@@ -632,6 +726,17 @@ void OprFootprint::init_all_footprints() { | |||||
add_single_param_json<opr::Dimshuffle>(); | add_single_param_json<opr::Dimshuffle>(); | ||||
add_single_param_json<opr::AxisAddRemove>(); | add_single_param_json<opr::AxisAddRemove>(); | ||||
add_single_param_json<opr::Subtensor>(); | add_single_param_json<opr::Subtensor>(); | ||||
add_single_param_json<opr::SetSubtensor>(); | |||||
add_single_param_json<opr::IncrSubtensor>(); | |||||
add_single_param_json<opr::IndexingMultiAxisVec>(); | |||||
add_single_param_json<opr::IndexingSetMultiAxisVec>(); | |||||
add_single_param_json<opr::IndexingIncrMultiAxisVec>(); | |||||
add_single_param_json<opr::MeshIndexing>(); | |||||
add_single_param_json<opr::SetMeshIndexing>(); | |||||
add_single_param_json<opr::IncrMeshIndexing>(); | |||||
add_single_param_json<opr::BatchedMeshIndexing>(); | |||||
add_single_param_json<opr::BatchedSetMeshIndexing>(); | |||||
add_single_param_json<opr::BatchedIncrMeshIndexing>(); | |||||
add_single_param_json<opr::Reduce>(); | add_single_param_json<opr::Reduce>(); | ||||
add_single_param_json<opr::LocalShareForward>(); | add_single_param_json<opr::LocalShareForward>(); | ||||
add_single_param_json<opr::LocalShareBackwardData>(); | add_single_param_json<opr::LocalShareBackwardData>(); | ||||
@@ -639,7 +744,31 @@ void OprFootprint::init_all_footprints() { | |||||
add_single_param_json<opr::DeformableConvForward>(); | add_single_param_json<opr::DeformableConvForward>(); | ||||
add_single_param_json<opr::DeformableConvBackwardFilter>(); | add_single_param_json<opr::DeformableConvBackwardFilter>(); | ||||
add_single_param_json<opr::DeformableConvBackwardData>(); | add_single_param_json<opr::DeformableConvBackwardData>(); | ||||
add_single_param_json<opr::DeformablePSROIPoolingForward>(); | |||||
add_single_param_json<opr::BatchConvBiasForward>(); | add_single_param_json<opr::BatchConvBiasForward>(); | ||||
add_single_param_json<opr::BatchNormForward>(); | |||||
add_single_param_json<opr::Reshape>(); | |||||
add_single_param_json<opr::GetVarShape>(); | |||||
add_single_param_json<opr::Argsort>(); | |||||
add_single_param_json<opr::Argmin>(); | |||||
add_single_param_json<opr::Argmax>(); | |||||
add_single_param_json<opr::ElemwiseMultiType>(); | |||||
add_single_param_json<opr::AdaptivePooling>(); | |||||
add_single_param_json<opr::ROIPooling>(); | |||||
add_single_param_json<opr::ROIAlign>(); | |||||
add_single_param_json<opr::WarpPerspective>(); | |||||
add_single_param_json<opr::Remap>(); | |||||
add_single_param_json<opr::Resize>(); | |||||
add_single_param_json<opr::IndexingOneHot>(); | |||||
add_single_param_json<opr::IndexingSetOneHot>(); | |||||
add_single_param_json<opr::WarpAffine>(); | |||||
add_single_param_json<opr::TopK>(); | |||||
add_single_param_json<opr::UniformRNG>(); | |||||
add_single_param_json<opr::GaussianRNG>(); | |||||
add_single_param_json<opr::Linspace>(); | |||||
add_single_param_json<opr::Eye>(); | |||||
add_single_param_json<opr::standalone::NMSKeep>(); | |||||
add_single_param_json<opr::CvtColor>(); | |||||
#endif | #endif | ||||
} | } | ||||