|
- # -*- coding: utf-8 -*-
- import collections
- import fnmatch
- import itertools
- import pickle
- import re
- from collections import OrderedDict
- from typing import Any, Dict, List, Optional, Sequence
-
- from ..core import _imperative_rt
- from ..core._imperative_rt import ComputingGraph, SerializationMetadata
- from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
- from ..core.tensor import megbrain_graph as G
- from ..logger import get_logger
- from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
- from .network_node import (
- ConstOpBase,
- Host2DeviceCopy,
- ImmutableTensor,
- NetworkNode,
- OpNode,
- VarNode,
- str_to_mge_class,
- )
-
- logger = get_logger(__name__)
-
-
class Network:
    r"""Python-side, editable view of a computing graph.

    Low-level operators and vars are wrapped into :class:`OpNode` and
    :class:`VarNode` objects so that the graph can be loaded, inspected,
    modified (vars/oprs replaced, outputs added or removed) and dumped again.
    """

    def __init__(self):
        self.input_vars = []  # input var of graph
        self._orig_inputs = []
        self.output_vars = []  # output var of graph
        self._orig_outputs = []
        # _imperative_rt.graph.OperatorNode.id: OpNode
        self.all_oprs_map = OrderedDict()
        # _imperative_rt.graph.VarNode.id: VarNode
        self.all_vars_map = OrderedDict()
        self.graph = ComputingGraph()
        # filled in by :meth:`load`; stays None for a Network built from scratch
        self._metadata = None

    @property
    def metadata(self):
        r"""Load metadata as a dict.

        Returns:
            ``None`` when no valid metadata is attached (including a Network
            that was not created through :meth:`load`), otherwise a dict with
            the keys ``user_info``, ``graph_modified``,
            ``optimized_for_inference`` and, if the model was optimized for
            inference, the deserialized optimize options.
        """
        if self._metadata is None or not self._metadata.is_valid:
            logger.info("metadata is not valid!")
            return None
        ret = dict()
        try:
            # NOTE(security): unpickling executes arbitrary code; only read
            # metadata of model files that come from a trusted source.
            user_info = pickle.loads(self._metadata.user_info)
        except Exception:  # best effort: fall back to the raw bytes
            logger.warning(
                "can't parse user info by pickle, so return the original bytes object!"
            )
            user_info = self._metadata.user_info
        ret["user_info"] = user_info
        ret["graph_modified"] = self._metadata.graph_modified
        ret["optimized_for_inference"] = self._metadata.optimized_for_inference
        if ret["optimized_for_inference"]:
            ret.update(G.deserialize_infer_option(self._metadata.optimize_options))
        return ret

    @classmethod
    def load(cls, model_path: str, outspec: Optional[List[str]] = None):
        r"""Loads a computing graph as a Network object.

        Args:
            model_path: file path of mge model.
            outspec: only load the subgraph with outspec as its endpoints.

        Returns:
            the newly created :class:`Network`.
        """
        self = cls()
        ret = G.load_graph(model_path)
        outputs, self._metadata = ret.output_vars_list, ret.metadata
        if outspec is not None:
            # work on a private list so the caller's sequence is not mutated
            # (also accepts tuples, which have no ``copy`` method)
            output_spec = list(outspec)
            all_vars = get_dep_vars(outputs) + outputs
            new_outputs = {}
            for i in all_vars:
                if i.name in output_spec:
                    new_outputs[i.name] = i
                    output_spec.remove(i.name)
            assert len(output_spec) == 0, "Can not find {} in this model".format(
                output_spec
            )
            outputs = [new_outputs[i] for i in outspec]
        self._orig_outputs = outputs
        for x in self._orig_outputs:
            self.output_vars.append(self._get_var(x))
        self.add_dep_oprs()
        for x in self._orig_inputs:
            self.input_vars.append(self._get_var(x))

        self.graph = self._orig_outputs[0].graph
        return self

    def _compile(self):
        r"""Re-compiles every opr reachable from the outputs and rebuilds the id maps."""
        self.all_oprs_map = {}
        self.all_vars_map = {}
        for opr in self.all_oprs:
            if isinstance(opr, (ConstOpBase, Host2DeviceCopy)):
                opr.compile(self.graph)
            else:
                opr.compile()
            if opr.name is not None:
                opr._opr.name = opr.name
            self.all_oprs_map[opr._opr.id] = opr
            for o in opr.outputs:
                self.all_vars_map[o.var.id] = o

    def optimize_for_inference(self, dest_vars, **kwargs):
        r"""Applies optimize_for_inference pass for operator graph.

        Args:
            dest_vars: list of output vars in the operator graph

        Keyword Arguments:

            * enable_io16xc32 --
              whether to use float16 for I/O between oprs and use
              float32 as internal computation precision. Note the output var would be
              changed to float16.
            * enable_ioc16 --
              whether to use float16 for both I/O and computation
              precision.
            * enable_hwcd4 --
              whether to use NHWCD4 data layout. This is faster on some
              OpenCL backend.
            * enable_nchw88 --
              whether to use NCHW88 data layout, currently
              used in X86 AVX backend.
            * enable_nchw44 --
              whether to use NCHW44 data layout, currently
              used in arm backend.
            * enable_nchw44_dot --
              whether to use NCHW44_dot data layout, currently
              used in armv8.2+dotprod backend.
            * enable_nchw4 --
              whether to use NCHW4 data layout, currently
              used in nvidia backend(based on cudnn).
            * enable_nchw32 --
              whether to use NCHW32 data layout, currently
              used in nvidia backend with tensorcore(based on cudnn).
            * enable_chwn4 --
              whether to use CHWN4 data layout, currently
              used in nvidia backend with tensorcore.
            * enable_nchw64 --
              whether to use NCHW64 data layout, used for fast int4
              support on Nvidia GPU.
            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
              into one opr.
            * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
              input for inference on nvidia backend (this optimization pass will
              result in mismatch of the precision of output of training and
              inference)
            * enable_fuse_grain: fuse grain will be enabled by default to fuse
              grain operator to huge operator, you can disable it.

        Returns:
            list of :class:`VarNode` wrapping the optimized output vars.
        """

        if not isinstance(dest_vars, Sequence):
            dest_vars = [dest_vars]
        dest_vars = [G.VarNode(var.var) for var in dest_vars]
        # G.optimize_for_inference returns ``(vars, optimize_options)`` (see
        # its use in :meth:`dump`); only the vars are needed here.
        new_vars, _ = G.optimize_for_inference(dest_vars, **kwargs)
        # _get_var expects the raw runtime var, so unwrap G.VarNode the same
        # way :meth:`dump` does before handing vars to the runtime.
        return [
            self._get_var(var._node if isinstance(var, G.VarNode) else var)
            for var in new_vars
        ]

    def dump(
        self,
        file,
        *,
        keep_var_name: int = 1,
        keep_opr_name: bool = False,
        keep_param_name: bool = False,
        keep_opr_priority: bool = False,
        strip_info_file=None,
        append_json=False,
        optimize_for_inference=True,
        append=False,
        user_info: Any = None,
        enable_metadata=True,
        **kwargs
    ):
        r"""Serializes graph to file.

        Args:
            file: output file, could be file object or filename.
            append: whether output is appended to ``file``.
                Only works when ``file`` is str.
            keep_var_name: level for keeping variable names:

                * 0: none of the names are kept
                * 1: (default)keep names of output vars
                * 2: keep names of all (output and internal) vars

            keep_opr_name: whether to keep operator names.
            keep_param_name: whether to keep param names, so param values can be
                easily manipulated after loading model
            keep_opr_priority: whether to keep priority setting for operators
            strip_info_file: a string for path or a file handler. if is not None,
                then the dump information for code strip would be written to ``strip_info_file``
            append_json: will be checked when `strip_info_file` is not None. if set
                true, the information for code strip will be appended to strip_info_file.
                if set false, will rewrite strip_info_file
            optimize_for_inference: enable optimizations,
                will skip all optimize options if this is False. Default: True
            user_info: any type object, which will be pickled to bytes.
            enable_metadata: whether to save metadata into output file.

        Returns:
            the dump info object produced by ``G.dump_graph``.

        See more details in :meth:`~.trace.dump`.
        """

        def _set_var_name(var):
            graph_var = G.VarNode(var.var)
            graph_var.name = var.name
            return graph_var

        self._compile()
        out = list(map(_set_var_name, self.output_vars))

        if kwargs.pop("arg_names", False):
            logger.warning(
                '"arg_names" is not supported in Network.dump, rename input vars directly'
            )
        if kwargs.pop("output_names", False):
            logger.warning(
                '"output_names" is not supported in Network.dump, rename output vars directly'
            )
        if optimize_for_inference:
            out, optimize_options = G.optimize_for_inference(out, **kwargs)

        metadata = SerializationMetadata()
        if enable_metadata:
            metadata.is_valid = True
            metadata.graph_modified = True
            metadata.user_info = pickle.dumps(user_info)
            if optimize_for_inference:
                metadata.optimize_options = optimize_options

        G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out])
        dump_content, dump_info = G.dump_graph(
            out,
            keep_var_name=keep_var_name,
            keep_opr_name=keep_opr_name,
            keep_param_name=keep_param_name,
            keep_opr_priority=keep_opr_priority,
            strip_info_file=strip_info_file,
            append_json=append_json,
            metadata=metadata,
        )
        if isinstance(file, str):
            # we opened the file, so we are responsible for closing it
            with open(file, "ab" if append else "wb") as fout:
                fout.write(dump_content)
        else:
            # caller-owned file object: write only, never close
            file.write(dump_content)
        return dump_info

    def make_const(self, data, name=None, device=None):
        r"""Makes an ImmutableTensor OpNode to provide a parameter for the network."""
        node = ImmutableTensor(data, name, device, self.graph)
        node.compile(self.graph)
        return node.outputs[0]

    def make_input_node(self, shape, dtype, name=None, device=None):
        r"""Makes a Host2DeviceCopy OpNode to provide an input varnode for the network."""
        node = Host2DeviceCopy(shape, dtype, name, device)
        node.compile(self.graph)
        return node.outputs[0]

    def add_output(self, *vars: VarNode):
        r"""Adds vars into the network output node list"""
        if not all([var.owner for var in vars]):
            self.add_dep_oprs(*vars)
        for var in vars:
            # use method 'is' instead of 'in' to avoid
            # compare VarNode use elemwise equal
            if not any(var is _ for _ in self.output_vars):
                self.output_vars.append(var)

    def remove_output(self, *vars: VarNode):
        r"""Removes vars from the network output node list.

        Every occurrence of each var (compared by identity) is removed; a
        warning is logged for a var that is not an output.
        """
        for var in vars:
            # compare by identity: VarNode may implement elementwise equality.
            # Build the filtered list first instead of popping while iterating,
            # which would skip the element following each removed one.
            kept = [out_var for out_var in self.output_vars if out_var is not var]
            if len(kept) == len(self.output_vars):
                logger.warning(
                    "Failed to remove {}({}). Please check whether "
                    "this node is in the output list.".format(var.name, id(var))
                )
            else:
                # in-place so that aliases of the list (e.g. ``dest_vars``) stay valid
                self.output_vars[:] = kept

    def add_dep_oprs(self, *vars):
        r"""Registers the oprs that ``vars`` depend on, breadth first.

        Args:
            vars: :class:`VarNode` objects to start from; defaults to the
                current output vars when empty.

        Returns:
            list of the starting vars.
        """
        if len(vars) == 0:
            vars = self.output_vars

        assert all(isinstance(var, VarNode) for var in vars), "Only support add VarNode"

        # deque gives O(1) pops from the front of the BFS queue
        q = collections.deque(vars)
        while q:
            cur = q.popleft()
            if cur.owner is not None:
                continue
            if cur.name is None:
                cur.name = cur.var.name
            self.all_vars_map[cur.var.id] = cur
            mge_opr = cur.var.owner
            if get_opr_type(mge_opr) == "Host2DeviceCopy":
                self._orig_inputs.extend(mge_opr.outputs)
            cur.owner = self._add_opr(mge_opr)
            if cur.owner is None:
                # opr was already registered: reuse it and do not re-expand its inputs
                cur.owner = self.all_oprs_map[mge_opr.id]
                continue
            q.extend(cur.owner.inputs)
        return list(vars)

    def modify_opr_names(self, modifier):
        r"""Modifies names of operators **inplace**; useful for merging loaded
        network into another network

        Args:
            modifier(str or callable): a string to be prepended to the name, or a function
                that maps from name to name
        """
        if isinstance(modifier, str):
            om = modifier
            modifier = lambda v: "{}.{}".format(om, v)
        assert callable(modifier)
        for i in self.all_oprs:
            v0 = i.name
            v1 = modifier(v0)
            assert isinstance(v1, str)
            i.name = v1

    def reset_batch_size(self, batchsize, *, blacklist=()):
        r"""Helper for reset batch size; first dimension of all data providers
        not in blacklist are assumed to be the batch size

        Args:
            batchsize: the new batch size.
            blacklist: data provider names whose first dimension is not
                batch size
        """
        blacklist = set(blacklist)
        prev_batchsize = None
        for i in self.data_providers_filter:
            if i.name in blacklist:
                blacklist.remove(i.name)
            else:
                shp = list(i.shape)
                if prev_batchsize is None:
                    prev_batchsize = shp[0]
                else:
                    assert prev_batchsize == shp[0], (
                        "batchsize mismatch: batchsize={} "
                        "shape={} dp={}".format(prev_batchsize, shp, i.name)
                    )
                shp[0] = batchsize
                i.shape = tuple(shp)
        self._compile()
        assert prev_batchsize is not None, "no data provider found"
        assert not blacklist, "unused items in blacklist: {}".format(blacklist)

    def replace_vars(self, repl_dict: Dict[VarNode, VarNode]):
        r"""Replaces vars in the graph.

        Args:
            repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
        """
        if not all([var.owner for var in repl_dict.values()]):
            self.add_dep_oprs(*list(repl_dict.values()))
        for var in self.all_vars:
            if var in repl_dict:
                repl_var = repl_dict[var]
                if repl_var is var:
                    continue
                for opnode in var.users:
                    # use method 'is' instead of 'in' to avoid
                    # compare VarNode use elemwise equal
                    assert any([var is _ for _ in opnode.inputs])
                    opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
                    if opnode not in repl_var.users:
                        repl_var.users.append(opnode)
                var.users.clear()
        self._compile()

    def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
        r"""Replaces operators in the graph.

        Args:
            repl_dict: the map {old_opr: new_opr} that specifies how to replace the operators.
        """
        for opr in self.all_oprs:
            if opr in repl_dict:
                assert len(opr.outputs) == len(
                    repl_dict[opr].outputs
                ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr]))
                for ind, var in enumerate(opr.outputs):
                    var.owner = repl_dict[opr]
                    var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
                    var._reset_var(repl_dict[opr].outputs[ind].var)
                repl_dict[opr].outputs = opr.outputs
        self._compile()

    def get_opr_by_type(self, oprcls, unique=True):
        r"""Gets oprs of class ``oprcls``; the single opr when ``unique``, else a list."""
        assert issubclass(oprcls, OpNode)
        rst = self.opr_filter.type(oprcls).as_list()
        if unique:
            assert len(rst) == 1, "{} operators of type {} found".format(
                len(rst), oprcls
            )
            (rst,) = rst
        return rst

    def get_opr_by_name(self, name, unique=True):
        r"""Gets oprs whose name matches ``name``; the single opr when ``unique``, else a list."""
        rst = self.opr_filter.name(name).as_list()
        if unique:
            assert len(rst) == 1, "{} operators named {} found".format(len(rst), name)
            (rst,) = rst
        return rst

    def get_var_by_name(self, name, unique=True):
        r"""Gets vars whose name matches ``name``; the single var when ``unique``, else a list."""
        rst = self.var_filter.name(name).as_list()
        if unique:
            assert len(rst) == 1, "{} vars named {} found".format(len(rst), name)
            (rst,) = rst
        return rst

    def get_var_receive_oprs(self, var):
        r"""Gets all oprs which use var as input"""
        return self.opr_filter.has_input(var).as_list()

    def get_dep_oprs(self, var):
        r"""Gets dependent oprs of var"""
        return get_oprs_seq(var, False, False)

    @property
    def opr_filter(self):
        r"""Filter on all opnodes of the Network."""
        oprs = self.all_oprs
        return NodeFilter(itertools.islice(oprs, len(oprs)))

    @property
    def var_filter(self):
        r"""Filter on all varnode of the Network."""
        vars = self.all_vars
        return NodeFilter(itertools.islice(vars, len(vars)))

    @property
    def params_filter(self):  # all immutable tensor
        r"""Filter on all parameters (ImmutableTensor Opr) of the Network"""
        return self.opr_filter.param_provider()

    @property
    def data_providers_filter(self):  # all host2devicecopy
        r"""Filter on all input nodes (Host2DeviceCopy Opr) of the Network"""
        return self.opr_filter.data_provider()

    @property
    def dest_vars(self):
        r"""Output varnodes of the Network."""
        return self.output_vars

    @property
    def all_oprs(self):
        r"""Oprs the output vars depend on, as returned by ``get_oprs_seq``."""
        return get_oprs_seq(self.output_vars, False, False)

    @property
    def all_vars(self):
        r"""Vars the output vars depend on, as returned by ``get_dep_vars``."""
        return get_dep_vars(self.output_vars)

    @property
    def all_vars_dict(self):
        r"""Ordered dict mapping var name to var."""
        return self.var_filter.as_dict()

    @property
    def all_oprs_dict(self):
        r"""Ordered dict mapping opr name to opr."""
        return self.opr_filter.as_dict()

    def _add_opr(self, opr) -> Optional[OpNode]:
        r"""Used for loading and building graph.

        Returns:
            the new :class:`OpNode`, or ``None`` if ``opr`` was already registered.
        """
        assert isinstance(opr, _imperative_rt.graph.OperatorNode)

        # TODO: use megbrain C++ RTTI to replace type string
        if opr.id not in self.all_oprs_map:
            opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
            self.all_oprs_map[opr.id] = opnode
            for var in opr.inputs:
                varnode = self._get_var(var)
                opnode.add_inp_var(varnode)
                varnode.users.append(opnode)
            for var in opr.outputs:
                opnode.add_out_var(self._get_var(var))
            return opnode
        else:
            # overwrite the opnode 'new' output VarNode with
            # original one when output number larger than 1,
            # or will cause dependence issue in _compiler step.
            if len(opr.outputs) > 1:
                opnode = self.all_oprs_map[opr.id]
                for idx, output in enumerate(opnode.outputs):
                    if output.var.id in self.all_vars_map:
                        opnode.outputs[idx] = self.all_vars_map[output.var.id]

            return None

    def _get_opr(self, x):
        r"""Returns the registered :class:`OpNode` for runtime opr ``x``, or ``None``."""
        if x.id in self.all_oprs_map:
            return self.all_oprs_map[x.id]
        else:
            return None

    def _get_var(self, x):
        r"""Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`."""
        assert isinstance(x, _imperative_rt.graph.VarNode)
        # (re)create the wrapper when the id is unknown or now maps to a different var
        if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
            self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
        return self.all_vars_map[x.id]
-
-
def set_symbolic_shape(option: bool):
    r"""Switch symbolic-shape mode for VarNode on or off.

    Thin wrapper that forwards to the core ``_set_symbolic_shape`` option
    setter. To change the input batch size, enable it before editing and
    restore the returned value after dumping.

    Args:
        option: True to enable symbolic shape.

    Returns:
        the value returned by the core setter (the last status).
    """
    last_status = _set_symbolic_shape(option)
    return last_status
-
-
def as_varnode(obj):
    r"""Coerce ``obj`` into a :class:`.utils.network_node.VarNode`.

    Args:
        obj: one of

            1. a :class:`.utils.network_node.VarNode` object (returned as is)
            2. a :class:`.utils.network_node.OpNode` object with exactly one output
            3. an iterable of length 1 yielding either of the above

    Returns:
        the resulting :class:`.utils.network_node.VarNode`.
    """

    def _shorten(text):
        # keep assertion messages readable for long sequences
        return text if len(text) < 50 else text[:50] + " ..."

    if type(obj) is VarNode:
        return obj

    if isinstance(obj, OpNode):
        nr_out = len(obj.outputs)
        assert nr_out == 1, (
            "operator {} must have one output to be converted to VarNode; "
            "got {} actually".format(obj, len(obj.outputs))
        )
        (only_output,) = obj.outputs
        assert type(only_output) is VarNode
        return only_output

    assert isinstance(
        obj, collections.abc.Iterable
    ), "{} is not compatible with VarNode".format(obj)

    items = list(obj)
    assert (
        len(items) == 1
    ), "can not convert sequence of length {} to VarNode ({})".format(
        len(items), _shorten(str(items))
    )
    # the single element may itself be an OpNode or a nested iterable
    return as_varnode(items[0])
-
-
def as_oprnode(obj):
    r"""Coerce ``obj`` into a :class:`.utils.network_node.OpNode`.

    Mirrors :func:`as_varnode`: a VarNode yields its owner, an OpNode is
    returned unchanged, and an iterable of length 1 is unwrapped recursively.
    """
    if isinstance(obj, OpNode):
        return obj

    if type(obj) is VarNode:
        return obj.owner

    assert isinstance(
        obj, collections.abc.Iterable
    ), "{} is not compatible with OpNode".format(obj)

    items = list(obj)
    nr_items = len(items)
    assert (
        nr_items == 1
    ), "can not convert sequence of length {} to " "OpNode({})".format(
        len(items), items
    )
    return as_oprnode(items[0])
-
-
class NodeFilter:
    r"""Chainable filter over an iterator of :class:`.NetworkNode` objects.

    Every filtering method returns a new :class:`NodeFilter` wrapping this
    one, so conditions can be chained; the ``as_*`` methods consume the chain.

    Example:

        .. code-block::

            # find all :class:`.ImmutableTensor` nodes
            for i in NodeFilter(node_iter).param_provider():
                print(i)

            # find all :class:`.ImmutableTensor` nodes that end with ':W'
            for i in NodeFilter(node_iter).param_provider().name('*:W'):
                print(i)

            # number of inputs
            nr_input = NodeFilter(node_iter).data_provider().as_count()
    """

    _iter = None

    def __init__(self, node_iter):
        r"""
        Args:
            node_iter: iterator to :class:`.NetworkNode`, or a
                :class:`.VarNode` / :class:`.OpNode` object; in the latter
                case its dependent oprs (all but the last one returned by
                ``get_oprs_seq``) are used
        """
        if isinstance(node_iter, VarNode):
            dep_oprs = get_oprs_seq(node_iter, False, False)
            node_iter = itertools.islice(dep_oprs, len(dep_oprs) - 1)
        if isinstance(node_iter, OpNode):
            dep_oprs = get_oprs_seq(node_iter.inputs, False, False)
            node_iter = itertools.islice(dep_oprs, len(dep_oprs) - 1)

        assert isinstance(node_iter, collections.abc.Iterable)
        # A raw iterable gets wrapped in a type check, unless it is already a
        # filter or we are the type-checking filter ourselves.
        skip_wrapping = (
            isinstance(node_iter, NodeFilter) or type(self) is NodeFilterCheckType
        )
        if not skip_wrapping:
            node_iter = NodeFilterCheckType(node_iter, NetworkNode)
        self._iter = node_iter

    @classmethod
    def make_all_deps(cls, *dest_vars):
        r"""Build a :class:`NodeFilter` over all deps of the given vars."""
        all_deps = list(get_oprs_seq(dest_vars, False, False))
        return cls(all_deps)

    def __iter__(self):
        r"""Plain pass-through; subclasses override this to filter."""
        return iter(self._iter)

    def type(self, node_type):
        r"""Keep only nodes that are instances of ``node_type``.

        Args:
            node_type: node type class

        Returns:
            a new :class:`NodeFilter` object
        """
        return NodeFilterType(self, node_type)

    def check_type(self, node_type):
        r"""Assert that every node produced is an instance of ``node_type``.

        Args:
            node_type: node type class

        Returns:
            a new :class:`NodeFilter` object

        Raises:
            TypeError if type check failed (raised while iterating)
        """
        return NodeFilterCheckType(self, node_type)

    def not_type(self, node_type):
        r"""Drop nodes that are instances of ``node_type``.

        Args:
            node_type: node type class

        Returns:
            a new :class:`NodeFilter` object
        """
        return NodeFilterNotType(self, node_type)

    def param_provider(self):
        r"""Keep only :class:`.ImmutableTensor` oprs; shorthand for
        ``.type(ImmutableTensor)``
        """
        return self.type(ImmutableTensor)

    def data_provider(self):
        r"""Keep only :class:`.Host2DeviceCopy` oprs; shorthand for
        ``.type(Host2DeviceCopy)``
        """
        return self.type(Host2DeviceCopy)

    def name(self, pattern, ignorecase=True):
        r"""Keep only nodes whose name matches ``pattern``.

        Args:
            pattern(class:`str`): a string in glob syntax that can contain ``?`` and
                ``*`` to match a single or arbitrary characters.
            ignorecase(bool, optional): whether to ignore case

        Returns:
            a new :class:`NodeFilter` object
        """
        return NodeFilterName(self, pattern, ignorecase)

    def has_input(self, var):
        r"""Keep only oprs that take ``var`` as one of their inputs.

        Args:
            var: var node to be checked

        Returns:
            a new :class:`NodeFilter` object
        """
        return NodeFilterHasInput(self, var)

    def as_list(self):
        r"""Consume this iterator and return its nodes as a list."""
        return [node for node in self]

    def as_unique(self):
        r"""Return the one and only node of this iterator.

        Returns:
            the unique node

        Raises:
            ValueError if this iterator does not yield a unique node
        """
        [only_node] = self
        return only_node

    def as_dict(self):
        r"""Consume this iterator into an ordered dict mapping node name to node."""
        by_name = collections.OrderedDict()
        for node in self:
            by_name[node.name] = node
        return by_name

    def as_count(self):
        r"""Consume this iterator and return how many nodes it yielded."""
        count = 0
        for _ in self:
            count += 1
        return count
-
-
class NodeFilterType(NodeFilter):
    r"""Yields only nodes of a given type; see :meth:`NodeFilter.type`"""

    _node_type = None

    def __init__(self, node_iter, node_type):
        assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(node_type)
        super().__init__(node_iter)
        self._node_type = node_type

    def __iter__(self):
        for node in self._iter:
            if not isinstance(node, self._node_type):
                continue
            yield node
-
-
class NodeFilterNotType(NodeFilterType):
    r"""Yields only nodes NOT of a given type; see :meth:`NodeFilter.not_type`"""

    def __iter__(self):
        for node in self._iter:
            if isinstance(node, self._node_type):
                continue
            yield node
-
-
class NodeFilterCheckType(NodeFilterType):
    r"""Passes nodes through while asserting their type; see
    :meth:`NodeFilter.check_type`
    """

    def __iter__(self):
        for node in self._iter:
            if isinstance(node, self._node_type):
                yield node
                continue
            # first offending node aborts the iteration
            raise TypeError(
                "all nodes should be {}; got {!r}".format(self._node_type, node)
            )
-
-
class NodeFilterHasInput(NodeFilter):
    r"""Yields oprs that consume a given var; see :meth:`NodeFilter.has_input`"""

    _var = None

    def __init__(self, node_iter, var):
        # normalize first so an incompatible ``var`` fails before wrapping the iterator
        var = as_varnode(var)
        super().__init__(node_iter)
        self.var = var

    def __iter__(self):
        for opr in self._iter:
            assert isinstance(
                opr, OpNode
            ), "has_input() must be used with OpNode; " "got {!r}".format(opr)
            # identity comparison: VarNode equality may be elementwise
            for inp in opr.inputs:
                if self.var is inp:
                    yield opr
                    break
-
-
class NodeFilterName(NodeFilter):
    r"""Yields nodes whose name matches a glob pattern; see :meth:`NodeFilter.name`"""

    _re = None

    def __init__(self, node_iter, pattern, ignorecase):
        super().__init__(node_iter)
        self.pattern = pattern
        self._re = self.make_re(pattern, ignorecase)

    @classmethod
    def make_re(cls, pattern, ignorecase=True):
        r"""Compile a glob ``pattern`` into a regular expression.

        Args:
            pattern: glob-syntax string (``?`` and ``*`` wildcards).
            ignorecase: whether matching ignores case.

        Returns:
            the compiled pattern object.
        """
        assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
        assert isinstance(ignorecase, bool)
        flags = 0
        if ignorecase:
            flags |= re.IGNORECASE
        return re.compile(fnmatch.translate(pattern), flags=flags)

    def __iter__(self):
        for i in self._iter:
            node_name = i.name
            if node_name is None:
                # unnamed nodes can never match; ``re.match(None)`` would
                # raise TypeError and abort the whole iteration
                continue
            # the literal comparison lets names containing glob
            # metacharacters match themselves exactly
            if self.pattern == node_name or self._re.match(node_name):
                yield i
|