You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

network.py 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817
  1. # -*- coding: utf-8 -*-
  import collections
  import collections.abc
  import fnmatch
  import itertools
  import pickle
  import re
  from collections import OrderedDict
  from typing import Any, Dict, List, Optional, Sequence

  from ..core import _imperative_rt
  from ..core._imperative_rt import ComputingGraph, SerializationMetadata
  from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
  from ..core.tensor import megbrain_graph as G
  from ..logger import get_logger
  from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
  from .network_node import (
      ConstOpBase,
      Host2DeviceCopy,
      ImmutableTensor,
      NetworkNode,
      OpNode,
      VarNode,
      str_to_mge_class,
  )
  24. logger = get_logger(__name__)
  class Network:
      r"""Editable python-side representation of a computing graph.

      Raw ``_imperative_rt`` operators and vars are wrapped into
      :class:`OpNode` and :class:`VarNode` objects (see :meth:`_add_opr` and
      :meth:`_get_var`) so that the graph can be inspected, modified and
      dumped again.
      """

      def __init__(self):
          self.input_vars = []  # input var of graph
          self._orig_inputs = []
          self.output_vars = []  # output var of graph
          self._orig_outputs = []
          # _imperative_rt.graph.OperatorNode.id: OpNode
          self.all_oprs_map = OrderedDict()
          # _imperative_rt.graph.VarNode.id: VarNode
          self.all_vars_map = OrderedDict()
          self.graph = ComputingGraph()
          # filled in by :meth:`load`; stays None for a network built by hand
          self._metadata = None

      @property
      def metadata(self):
          r"""Load metadata as a dict.

          Returns:
              ``None`` when no metadata is attached or it is flagged invalid;
              otherwise a dict with the keys ``user_info``, ``graph_modified``
              and ``optimized_for_inference``, extended with the deserialized
              inference options when the model was optimized for inference.
          """
          if self._metadata is None or not self._metadata.is_valid:
              logger.info("metadata is not valid!")
              return None
          ret = dict()
          try:
              # NOTE(security): unpickling runs arbitrary code, so only read
              # metadata of models that come from a trusted source.
              user_info = pickle.loads(self._metadata.user_info)
          except Exception:  # best effort: fall back to the raw bytes
              logger.warning(
                  "can't parse user info by pickle, so return the original bytes object!"
              )
              user_info = self._metadata.user_info
          ret["user_info"] = user_info
          ret["graph_modified"] = self._metadata.graph_modified
          ret["optimized_for_inference"] = self._metadata.optimized_for_inference
          if ret["optimized_for_inference"]:
              ret.update(G.deserialize_infer_option(self._metadata.optimize_options))
          return ret

      @classmethod
      def load(cls, model_path: str, outspec: List[str] = None):
          r"""Loads a computing graph as a Network object.

          Args:
              model_path: file path of mge model.
              outspec: only load the subgraph with outspec as its endpoints.

          Returns:
              the loaded :class:`Network`.

          Raises:
              AssertionError: if a name listed in ``outspec`` is not found
                  among the vars of the loaded model.
          """
          self = cls()
          ret = G.load_graph(model_path)
          outputs, self._metadata = ret.output_vars_list, ret.metadata
          if outspec is not None:
              # work on a copy so that the caller's list is left untouched
              output_spec = list(outspec)
              all_vars = get_dep_vars(outputs) + outputs
              new_outputs = {}
              for i in all_vars:
                  if i.name in output_spec:
                      new_outputs[i.name] = i
                      output_spec.remove(i.name)
              assert len(output_spec) == 0, "Can not find {} in this model".format(
                  output_spec
              )
              # keep the order requested by the caller
              outputs = [new_outputs[i] for i in outspec]
          self._orig_outputs = outputs
          for x in self._orig_outputs:
              self.output_vars.append(self._get_var(x))
          # add_dep_oprs also collects the Host2DeviceCopy outputs into
          # self._orig_inputs, which is why inputs are wrapped afterwards
          self.add_dep_oprs()
          for x in self._orig_inputs:
              self.input_vars.append(self._get_var(x))
          self.graph = self._orig_outputs[0].graph
          return self

      def _compile(self):
          r"""Compiles every operator and rebuilds the id -> node lookup maps."""
          self.all_oprs_map = {}
          self.all_vars_map = {}
          for opr in self.all_oprs:
              if isinstance(opr, (ConstOpBase, Host2DeviceCopy)):
                  # constant and input operators are compiled against the graph
                  opr.compile(self.graph)
              else:
                  opr.compile()
              if opr.name is not None:
                  opr._opr.name = opr.name
              self.all_oprs_map[opr._opr.id] = opr
              for o in opr.outputs:
                  self.all_vars_map[o.var.id] = o

      def optimize_for_inference(self, dest_vars, **kwargs):
          r"""Applies optimize_for_inference pass for operator graph.

          Args:
              dest_vars: list of output vars in the operator graph

          Keyword Arguments:

              * enable_io16xc32 --
                  whether to use float16 for I/O between oprs and use
                  float32 as internal computation precision. Note the output var would be
                  changed to float16.
              * enable_ioc16 --
                  whether to use float16 for both I/O and computation
                  precision.
              * enable_hwcd4 --
                  whether to use NHWCD4 data layout. This is faster on some
                  OpenCL backend.
              * enable_nchw88 --
                  whether to use NCHW88 data layout, currently
                  used in X86 AVX backend.
              * enable_nchw44 --
                  whether to use NCHW44 data layout, currently
                  used in arm backend.
              * enable_nchw44_dot --
                  whether to use NCHW44_dot data layout, currently
                  used in armv8.2+dotprod backend.
              * enable_nchw4 --
                  whether to use NCHW4 data layout, currently
                  used in nvidia backend(based on cudnn).
              * enable_nchw32 --
                  whether to use NCHW32 data layout, currently
                  used in nvidia backend with tensorcore(based on cudnn).
              * enable_chwn4 --
                  whether to use CHWN4 data layout, currently
                  used in nvidia backend with tensorcore.
              * enable_nchw64 --
                  whether to use NCHW64 data layout, used for fast int4
                  support on Nvidia GPU.
              * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
                  into one opr.
              * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
                  input for inference on nvidia backend(this optimization pass will
                  result in mismatch of the precision of output of training and
                  inference)

          Returns:
              list of :class:`VarNode` wrapping the optimized output vars.
          """
          if not isinstance(dest_vars, Sequence):
              dest_vars = [dest_vars]
          dest_vars = [G.VarNode(var.var) for var in dest_vars]
          # G.optimize_for_inference yields (vars, serialized options), exactly
          # as it is unpacked in dump(); the options are not needed here.
          new_vars, _ = G.optimize_for_inference(dest_vars, **kwargs)
          # _get_var expects the raw graph node, so unwrap G.VarNode first
          return [
              self._get_var(var._node if isinstance(var, G.VarNode) else var)
              for var in new_vars
          ]

      def dump(
          self,
          file,
          *,
          keep_var_name: int = 1,
          keep_opr_name: bool = False,
          keep_param_name: bool = False,
          keep_opr_priority: bool = False,
          strip_info_file=None,
          append_json=False,
          optimize_for_inference=True,
          append=False,
          user_info: Any = None,
          enable_metadata=True,
          **kwargs
      ):
          r"""Serializes graph to file.

          Args:
              file: output file, could be file object or filename.
              append: whether output is appended to ``file``.
                  Only works when ``file`` is str.
              keep_var_name: level for keeping variable names:

                  * 0: none of the names are kept
                  * 1: (default)keep names of output vars
                  * 2: keep names of all (output and internal) vars

              keep_opr_name: whether to keep operator names.
              keep_param_name: whether to keep param names, so param values can be
                  easily manipulated after loading model
              keep_opr_priority: whether to keep priority setting for operators
              strip_info_file: a string for path or a file handler. if is not None,
                  then the dump information for code strip would be written to ``strip_info_file``
              append_json: will be checked when ``strip_info_file`` is not None. if set
                  true, the information for code strip will be appended to strip_info_file.
                  if set false, will rewrite strip_info_file
              optimize_for_inference: enable optimizations,
                  will skip all optimize options if this is False. Default: True
              user_info: any type object, which will be pickled to bytes.
              enable_metadata: whether to save metadata into output file.

          See more details in :meth:`~.trace.dump`.

          Returns:
              the dump information returned by ``G.dump_graph``.
          """

          def _set_var_name(var):
              graph_var = G.VarNode(var.var)
              graph_var.name = var.name
              return graph_var

          self._compile()
          out = list(map(_set_var_name, self.output_vars))

          # these two options are accepted for compatibility but ignored
          if kwargs.pop("arg_names", False):
              logger.warning(
                  '"arg_names" is not supported in Network.dump, rename input vars directly'
              )
          if kwargs.pop("output_names", False):
              logger.warning(
                  '"output_names" is not supported in Network.dump, rename output vars directly'
              )

          if optimize_for_inference:
              out, optimize_options = G.optimize_for_inference(out, **kwargs)

          metadata = SerializationMetadata()
          if enable_metadata:
              metadata.is_valid = True
              metadata.graph_modified = True
              metadata.user_info = pickle.dumps(user_info)
              if optimize_for_inference:
                  metadata.optimize_options = optimize_options

          G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out])
          dump_content, dump_info = G.dump_graph(
              out,
              keep_var_name=keep_var_name,
              keep_opr_name=keep_opr_name,
              keep_param_name=keep_param_name,
              keep_opr_priority=keep_opr_priority,
              strip_info_file=strip_info_file,
              append_json=append_json,
              metadata=metadata,
          )
          if isinstance(file, str):
              # the file is opened here, so it is also closed here
              with open(file, "ab" if append else "wb") as fout:
                  fout.write(dump_content)
          else:
              # a caller-provided file object stays open for the caller
              file.write(dump_content)
          return dump_info

      def make_const(self, data, name=None, device=None):
          r"""Makes an ImmutableTensor OpNode to provide a parameter for the network."""
          node = ImmutableTensor(data, name, device, self.graph)
          node.compile(self.graph)
          return node.outputs[0]

      def make_input_node(self, shape, dtype, name=None, device=None):
          r"""Makes a Host2DeviceCopy OpNode to provide an input varnode for the network."""
          node = Host2DeviceCopy(shape, dtype, name, device)
          node.compile(self.graph)
          return node.outputs[0]

      def add_output(self, *vars: VarNode):
          r"""Adds vars into the network output node list"""
          if not all(var.owner for var in vars):
              self.add_dep_oprs(*vars)
          for var in vars:
              # use method 'is' instead of 'in' to avoid
              # compare VarNode use elemwise equal
              if not any(var is _ for _ in self.output_vars):
                  self.output_vars.append(var)

      def remove_output(self, *vars: VarNode):
          r"""Removes vars from the network output node list"""
          for var in vars:
              # filter by identity instead of list.remove to avoid
              # compare VarNode use elemwise equal; building a new list also
              # avoids mutating output_vars while it is being iterated
              kept = [out_var for out_var in self.output_vars if out_var is not var]
              if len(kept) == len(self.output_vars):
                  logger.warning(
                      "Failed to remove {}({}). Please check whether "
                      "this node is in the output list.".format(var.name, id(var))
                  )
              else:
                  # slice assignment keeps the identity of the list object
                  self.output_vars[:] = kept

      def add_dep_oprs(self, *vars):
          r"""Registers the operators that the given vars depend on.

          Vars that already have an owner are skipped. When called without
          arguments, the current output vars are used.

          Args:
              vars: :class:`VarNode` objects whose producers should be added.

          Returns:
              the processed vars as a list.
          """
          if len(vars) == 0:
              vars = self.output_vars

          assert all(isinstance(var, VarNode) for var in vars), "Only support add VarNode"

          q = collections.deque(vars)  # FIFO queue with O(1) pops from the left
          while q:
              cur = q.popleft()
              if cur.owner is not None:
                  continue
              if cur.name is None:
                  cur.name = cur.var.name
              self.all_vars_map[cur.var.id] = cur
              mge_opr = cur.var.owner
              if get_opr_type(mge_opr) == "Host2DeviceCopy":
                  self._orig_inputs.extend(mge_opr.outputs)
              cur.owner = self._add_opr(mge_opr)
              if cur.owner is None:
                  # _add_opr returns None when the operator is already known
                  cur.owner = self.all_oprs_map[mge_opr.id]
                  continue
              q.extend(cur.owner.inputs)
          return list(vars)

      def modify_opr_names(self, modifier):
          r"""Modifies names of operators **inplace**; useful for merging loaded
          network into another network

          Args:
              modifier(str or callable): a string to be prepended to the name, or a function
                  that maps from name to name
          """
          if isinstance(modifier, str):
              om = modifier
              modifier = lambda v: "{}.{}".format(om, v)

          # ``collections.Callable`` no longer exists on Python 3.10+
          assert callable(modifier)
          for i in self.all_oprs:
              v0 = i.name
              v1 = modifier(v0)
              assert isinstance(v1, str)
              i.name = v1

      def reset_batch_size(self, batchsize, *, blacklist=()):
          r"""Helper for reset batch size; first dimension of all data providers
          not in blacklist are assumed to be the batch size

          Args:
              batchsize: the new value of the first dimension.
              blacklist: data provider names whose first dimension is not
                  batch size
          """
          blacklist = set(blacklist)
          prev_batchsize = None
          for i in self.data_providers_filter:
              if i.name in blacklist:
                  blacklist.remove(i.name)
              else:
                  shp = list(i.shape)
                  if prev_batchsize is None:
                      prev_batchsize = shp[0]
                  else:
                      assert prev_batchsize == shp[0], (
                          "batchsize mismatch: batchsize={} "
                          "shape={} dp={}".format(prev_batchsize, shp, i.name)
                      )
                  shp[0] = batchsize
                  i.shape = tuple(shp)
          self._compile()
          assert prev_batchsize is not None, "no data provider found"
          assert not blacklist, "unused items in blacklist: {}".format(blacklist)

      def replace_vars(self, repl_dict: Dict[VarNode, VarNode]):
          r"""Replaces vars in the graph.

          Args:
              repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
          """
          if not all(var.owner for var in repl_dict.values()):
              self.add_dep_oprs(*list(repl_dict.values()))
          for var in self.all_vars:
              if var in repl_dict:
                  repl_var = repl_dict[var]
                  if repl_var is var:
                      continue
                  for opnode in var.users:
                      # use method 'is' instead of 'in' to avoid
                      # compare VarNode use elemwise equal
                      assert any(var is _ for _ in opnode.inputs)
                      opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
                      if opnode not in repl_var.users:
                          repl_var.users.append(opnode)
                  var.users.clear()
          self._compile()

      def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
          r"""Replaces operators in the graph.

          Args:
              repl_dict: the map {old_opr: new_opr} that specifies how to replace the operators.
          """
          for opr in self.all_oprs:
              if opr in repl_dict:
                  new_opr = repl_dict[opr]
                  assert len(opr.outputs) == len(
                      new_opr.outputs
                  ), "can not replace {} with {}".format(type(opr), type(new_opr))
                  for ind, var in enumerate(opr.outputs):
                      # the old output objects are kept (their users still
                      # reference them) and take over the new outputs' state
                      var.owner = new_opr
                      var.__dict__.update(new_opr.outputs[ind].__dict__)
                      var._reset_var(new_opr.outputs[ind].var)
                  new_opr.outputs = opr.outputs
          self._compile()

      def get_opr_by_type(self, oprcls, unique=True):
          r"""Gets operators of class ``oprcls``; a single node when ``unique``."""
          assert issubclass(oprcls, OpNode)
          rst = self.opr_filter.type(oprcls).as_list()
          if unique:
              assert len(rst) == 1, "{} operators of type {} found".format(
                  len(rst), oprcls
              )
              (rst,) = rst
          return rst

      def get_opr_by_name(self, name, unique=True):
          r"""Gets operators whose name matches ``name``; a single node when ``unique``."""
          rst = self.opr_filter.name(name).as_list()
          if unique:
              assert len(rst) == 1, "{} operators named {} found".format(len(rst), name)
              (rst,) = rst
          return rst

      def get_var_by_name(self, name, unique=True):
          r"""Gets vars whose name matches ``name``; a single node when ``unique``."""
          rst = self.var_filter.name(name).as_list()
          if unique:
              assert len(rst) == 1, "{} vars named {} found".format(len(rst), name)
              (rst,) = rst
          return rst

      def get_var_receive_oprs(self, var):
          r"""Gets all oprs which use var as input"""
          return self.opr_filter.has_input(var).as_list()

      def get_dep_oprs(self, var):
          r"""Gets dependent oprs of var"""
          return get_oprs_seq(var, False, False)

      @property
      def opr_filter(self):
          r"""Filter on all opnodes of the Network."""
          oprs = self.all_oprs
          return NodeFilter(itertools.islice(oprs, len(oprs)))

      @property
      def var_filter(self):
          r"""Filter on all varnode of the Network."""
          vars = self.all_vars
          return NodeFilter(itertools.islice(vars, len(vars)))

      @property
      def params_filter(self):  # all immutable tensor
          r"""Filter on all parameters (ImmutableTensor Opr) of the Network"""
          return self.opr_filter.param_provider()

      @property
      def data_providers_filter(self):  # all host2devicecopy
          r"""Filter on all input nodes (Host2DeviceCopy Opr) of the Network"""
          return self.opr_filter.data_provider()

      @property
      def dest_vars(self):
          r"""Output varnodes of the Network."""
          return self.output_vars

      @property
      def all_oprs(self):
          r"""Operators that the output vars depend on, from ``get_oprs_seq``."""
          return get_oprs_seq(self.output_vars, False, False)

      @property
      def all_vars(self):
          r"""Vars that the output vars depend on, from ``get_dep_vars``."""
          return get_dep_vars(self.output_vars)

      @property
      def all_vars_dict(self):
          r"""Ordered dict mapping var names to varnodes."""
          return self.var_filter.as_dict()

      @property
      def all_oprs_dict(self):
          r"""Ordered dict mapping operator names to opnodes."""
          return self.opr_filter.as_dict()

      def _add_opr(self, opr) -> Optional[OpNode]:
          r"""Used for loading and building graph.

          Returns:
              the newly created :class:`OpNode`, or ``None`` when ``opr`` had
              already been registered.
          """
          assert isinstance(opr, _imperative_rt.graph.OperatorNode)

          # TODO: use megbrain C++ RTTI to replace type string
          if opr.id not in self.all_oprs_map:
              opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
              self.all_oprs_map[opr.id] = opnode
              for var in opr.inputs:
                  varnode = self._get_var(var)
                  opnode.add_inp_var(varnode)
                  varnode.users.append(opnode)
              for var in opr.outputs:
                  opnode.add_out_var(self._get_var(var))
              return opnode

          # overwrite the opnode 'new' output VarNode with
          # original one when output number larger than 1,
          # or will cause dependence issue in _compiler step.
          if len(opr.outputs) > 1:
              opnode = self.all_oprs_map[opr.id]
              for idx, output in enumerate(opnode.outputs):
                  if output.var.id in self.all_vars_map:
                      opnode.outputs[idx] = self.all_vars_map[output.var.id]
          return None

      def _get_opr(self, x):
          r"""Returns the registered :class:`OpNode` for raw operator ``x``, or ``None``."""
          return self.all_oprs_map.get(x.id)

      def _get_var(self, x):
          r"""Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`."""
          assert isinstance(x, _imperative_rt.graph.VarNode)
          # wrap on first sight, and re-wrap when the cached entry for this id
          # refers to a different underlying var
          if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
              self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
          return self.all_vars_map[x.id]
  def set_symbolic_shape(option: bool):
      r"""Switch symbolic-shape mode for VarNode and report the previous state.

      Enable it (and restore the old value once dumping is done) when the
      input batch size has to be changed.

      Args:
          option: True for enable symbolic shape.

      Returns:
          the status that was active before this call.
      """
      previous_status = _set_symbolic_shape(option)
      return previous_status
  def as_varnode(obj):
      r"""convert a :class:`.utils.network_node.VarNode` compatible object to :class:`.utils.network_node.VarNode`.

      Args:
          obj: it must be one of the following:

              1. a :class:`.utils.network_node.VarNode` object
              2. a :class:`.utils.network_node.OpNode` object that has unique output
              3. an iterable that produces either type 1 or 2, with length 1

      Returns:
          the corresponding :class:`.utils.network_node.VarNode`.

      Raises:
          AssertionError: if ``obj`` matches none of the accepted forms.
      """
      if type(obj) is VarNode:
          return obj

      if isinstance(obj, OpNode):
          assert len(obj.outputs) == 1, (
              "operator {} must have one output to be converted to VarNode; "
              "got {} actually".format(obj, len(obj.outputs))
          )
          ret = obj.outputs[0]
          assert type(ret) is VarNode
          return ret

      # the ABC lives in ``collections.abc``; the ``collections.Iterable``
      # alias no longer exists on Python 3.10+
      assert isinstance(
          obj, collections.abc.Iterable
      ), "{} is not compatible with VarNode".format(obj)

      def _preview(text):
          # keep the error message short for long sequences
          return text if len(text) < 50 else text[:50] + " ..."

      val = list(obj)
      assert (
          len(val) == 1
      ), "can not convert sequence of length {} to VarNode ({})".format(
          len(val), _preview(str(val))
      )
      # unwrap the single element; it may itself be an OpNode or an iterable
      return as_varnode(val[0])
  def as_oprnode(obj):
      r"""convert a :class:`.utils.network_node.OpNode` compatible object to
      :class:`.utils.network_node.OpNode`; it works like :func:`as_varnode`.

      Args:
          obj: a VarNode (its ``owner`` is returned), an OpNode (returned
              unchanged), or an iterable of length 1 producing one of those.

      Raises:
          AssertionError: if ``obj`` matches none of the accepted forms.
      """
      if type(obj) is VarNode:
          return obj.owner

      if isinstance(obj, OpNode):
          return obj

      # the ABC lives in ``collections.abc``; the ``collections.Iterable``
      # alias no longer exists on Python 3.10+
      assert isinstance(
          obj, collections.abc.Iterable
      ), "{} is not compatible with OpNode".format(obj)

      val = list(obj)
      assert (
          len(val) == 1
      ), "can not convert sequence of length {} to OpNode({})".format(len(val), val)
      return as_oprnode(val[0])
  class NodeFilter:
      r"""Filter on node iterator. This class is an iterator of
      :class:`.NetworkNode` objects and multiple filtering conditions and
      mappers can be chained.

      Example:

          .. code-block::

              # find all :class:`.ImmutableTensor` nodes
              for i in NodeFilter(node_iter).param_provider():
                  print(i)

              # find all :class:`.ImmutableTensor` nodes that end with ':W'
              for i in NodeFilter(node_iter).param_provider().name('*:W'):
                  print(i)

              # number of inputs
              nr_input = NodeFilter(node_iter).data_provider().as_count()
      """

      _iter = None

      def __init__(self, node_iter):
          r"""
          Args:
              node_iter: iterator to :class:`.NetworkNode`, or a
                  :class:`.VarNode`-compatible object; in the latter case, its
                  dependent oprs would be used
          """
          if isinstance(node_iter, VarNode):
              oprs = get_oprs_seq(node_iter, False, False)
              node_iter = itertools.islice(oprs, len(oprs) - 1)
          elif isinstance(node_iter, OpNode):
              oprs = get_oprs_seq(node_iter.inputs, False, False)
              node_iter = itertools.islice(oprs, len(oprs) - 1)

          # the ABC lives in ``collections.abc``; the ``collections.Iterable``
          # alias no longer exists on Python 3.10+
          assert isinstance(node_iter, collections.abc.Iterable)
          # a raw iterable is wrapped into a type-checking filter once; the
          # second condition keeps that wrapper from wrapping itself again
          if (not isinstance(node_iter, NodeFilter)) and type(
              self
          ) is not NodeFilterCheckType:
              node_iter = NodeFilterCheckType(node_iter, NetworkNode)
          self._iter = node_iter

      @classmethod
      def make_all_deps(cls, *dest_vars):
          r"""make a :class:`NodeFilter` that contains all deps of given vars"""
          return cls(list(get_oprs_seq(dest_vars, False, False)))

      def __iter__(self):
          r"""to be overwritten by subclass to implement filters"""
          return iter(self._iter)

      def type(self, node_type):
          r"""filter by specific node type

          Args:
              node_type: node type class

          Returns:
              a new :class:`NodeFilter` object
          """
          return NodeFilterType(self, node_type)

      def check_type(self, node_type):
          r"""assert that all oprs produced by this iterator are instances of
          certain type

          Args:
              node_type: node type class

          Returns:
              a new :class:`NodeFilter` object

          Raises:
              TypeError if type check failed
          """
          return NodeFilterCheckType(self, node_type)

      def not_type(self, node_type):
          r"""remove oprs of specific type

          Args:
              node_type: node type class

          Returns:
              a new :class:`NodeFilter` object
          """
          return NodeFilterNotType(self, node_type)

      def param_provider(self):
          r"""get :class:`.ImmutableTensor` oprs; shorthand for
          ``.type(ImmutableTensor)``
          """
          return self.type(ImmutableTensor)

      def data_provider(self):
          r"""get :class:`.Host2DeviceCopy` oprs; shorthand for
          ``.type(Host2DeviceCopy)``
          """
          return self.type(Host2DeviceCopy)

      def name(self, pattern, ignorecase=True):
          r"""filter by node name

          Args:
              pattern(class:`str`): a string in glob syntax that can contain ``?`` and
                  ``*`` to match a single or arbitrary characters.
              ignorecase(bool, optional): whether to ignore case

          Returns:
              a new :class:`NodeFilter` object
          """
          return NodeFilterName(self, pattern, ignorecase)

      def has_input(self, var):
          r"""an opr is kept if it has given var as one of its inputs

          Args:
              var: var node to be checked

          Returns:
              a new :class:`NodeFilter` object
          """
          return NodeFilterHasInput(self, var)

      def as_list(self):
          r"""consume this iterator and return its content as a list"""
          return list(self)

      def as_unique(self):
          r"""assert that this iterator yields only one node and return it

          Returns:
              class:`.NetworkNode`: the unique node

          Raises:
              ValueError if this iterator does not yield a unique node
          """
          # tuple unpacking raises ValueError unless exactly one node is yielded
          (opr,) = self
          return opr

      def as_dict(self):
          r"""construct an ordered dict to map from node names to objects in
          this iterator
          """
          return collections.OrderedDict((i.name, i) for i in self)

      def as_count(self):
          r"""consume this iterator and get the number of elements"""
          return sum(1 for _ in self)
  class NodeFilterType(NodeFilter):
      r"""Keeps only the nodes that are instances of a given class; see
      :meth:`NodeFilter.type`.
      """

      # class every yielded node has to be an instance of
      _node_type = None

      def __init__(self, node_iter, node_type):
          if not issubclass(node_type, NetworkNode):
              raise AssertionError("bad opr type: {}".format(node_type))
          super().__init__(node_iter)
          self._node_type = node_type

      def __iter__(self):
          yield from (
              node for node in self._iter if isinstance(node, self._node_type)
          )
  class NodeFilterNotType(NodeFilterType):
      r"""Drops the nodes that are instances of a given class; see
      :meth:`NodeFilter.not_type`.
      """

      def __iter__(self):
          yield from (
              node for node in self._iter if not isinstance(node, self._node_type)
          )
  class NodeFilterCheckType(NodeFilterType):
      r"""Passes every node through, failing on the first node of an
      unexpected class; see :meth:`NodeFilter.check_type`.
      """

      def __iter__(self):
          for node in self._iter:
              if isinstance(node, self._node_type):
                  yield node
                  continue
              raise TypeError(
                  "all nodes should be {}; got {!r}".format(self._node_type, node)
              )
  class NodeFilterHasInput(NodeFilter):
      r"""Keeps the operators that take a given var as input; see
      :meth:`NodeFilter.has_input`.
      """

      # class-level placeholder; instances store the target var as ``var``
      _var = None

      def __init__(self, node_iter, var):
          target = as_varnode(var)
          super().__init__(node_iter)
          self.var = target

      def __iter__(self):
          for node in self._iter:
              assert isinstance(
                  node, OpNode
              ), "has_input() must be used with OpNode; got {!r}".format(node)
              # identity comparison on purpose: the inputs are matched with
              # ``is`` rather than ``==``
              for inp in node.inputs:
                  if inp is self.var:
                      yield node
                      break
  class NodeFilterName(NodeFilter):
      r"""Keeps the nodes whose name equals or glob-matches a pattern; see
      :meth:`NodeFilter.name`.
      """

      # compiled regular expression derived from the glob pattern
      _re = None

      def __init__(self, node_iter, pattern, ignorecase):
          super().__init__(node_iter)
          self.pattern = pattern
          self._re = self.make_re(pattern, ignorecase)

      @classmethod
      def make_re(cls, pattern, ignorecase=True):
          r"""Compiles a glob ``pattern`` into a regular expression.

          Args:
              pattern: a string in glob syntax.
              ignorecase: whether the match ignores case.

          Returns:
              the compiled pattern object.
          """
          assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
          assert isinstance(ignorecase, bool)
          flags = 0
          if ignorecase:
              flags |= re.IGNORECASE
          return re.compile(fnmatch.translate(pattern), flags=flags)

      def __iter__(self):
          for i in self._iter:
              node_name = i.name
              if node_name is None:
                  # unnamed nodes can never match, and passing None to
                  # ``re.match`` would raise TypeError
                  continue
              # the literal comparison lets names containing glob
              # metacharacters match themselves
              if self.pattern == node_name or self._re.match(node_name):
                  yield i