You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tracing.py 44 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import contextlib
  11. import functools
  12. import itertools
  13. import json
  14. import os
  15. import typing
  16. import warnings
  17. import weakref
  18. import numpy as np
  19. from ..core._imperative_rt import GraphProfiler, common
  20. from ..core._imperative_rt.core2 import Tensor as RawTensor
  21. from ..core._imperative_rt.core2 import (
  22. TensorWeakRef,
  23. apply,
  24. set_compiled,
  25. set_tracing,
  26. skip_tracing,
  27. unset_compiled,
  28. unset_tracing,
  29. )
  30. from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
  31. from ..core._trace_option import set_symbolic_shape
  32. from ..core._wrap import device as as_device
  33. from ..core.ops.builtin import BackwardGraph, OpDef
  34. from ..core.ops.special import Const
  35. from ..core.tensor import megbrain_graph as G
  36. from ..core.tensor.utils import setscalar
  37. from ..utils.naming import auto_naming
  38. from .sublinear_memory_config import SublinearMemoryConfig
  39. def _input_node_use_static_shape():
  40. return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
  class TraceMismatchError(RuntimeError):
      """Raised when a replayed execution diverges from the recorded trace."""
  # The ``trace`` instance that is currently recording or replaying, if any.
  active_trace = None


  def is_tracing():
      """Return True when a trace is active and tracing is not being skipped."""
      return active_trace is not None and not skip_tracing
  @contextlib.contextmanager
  def exclude_from_trace():
      """Context manager that suspends tracing for the enclosed statements.

      When tracing is already being skipped the manager does nothing, so
      nested use is harmless. Otherwise it sets ``skip_tracing``, calls
      ``unset_tracing()``, notifies the active trace (if any) that an
      excluded region begins, and on exit clears ``skip_tracing`` and calls
      ``set_tracing()`` again.
      """
      global skip_tracing
      if not skip_tracing:
          try:
              skip_tracing = True
              unset_tracing()
              if active_trace is not None:
                  active_trace._begin_excluded_region()
              yield
          finally:
              skip_tracing = False
              set_tracing()
      else:
          # Already inside an excluded region: nothing to toggle.
          yield
  class TensorInfo:
      """Per-handle record kept by a trace for one tensor.

      The first group of slots holds attributes collected while recording;
      the second group holds graph resources used when executing.
      """

      __slots__ = (
          # collected attributes
          "name",
          "external",
          "data_read",
          "shape_read",
          "value_read",
          "exported",
          "device",
          "dtype",
          "shape",
          "is_const",
          "bound_data",
          # resources for execution
          "varnode",
          "data_setter",
          "shape_reader",
          "value_reader",
          "data_reader",
      )

      def __init__(self):
          # Only these slots receive a default. The remaining ones
          # (external, device, dtype, shape, is_const, varnode) are left
          # unset on purpose: other code probes them with ``hasattr``.
          defaulted = (
              "name",
              "exported",
              "data_read",
              "shape_read",
              "value_read",
              "bound_data",
              "data_setter",
              "shape_reader",
              "value_reader",
              "data_reader",
          )
          for attr in defaulted:
              setattr(self, attr, None)
  # Op types that ``trace._compile`` chains together with explicit links.
  _io_op_types = {RemoteSend, RemoteRecv, CollectiveComm}
  97. class trace:
  98. """
  99. Wraps a callable and provides:
  100. * tracing via :meth:`.trace` and :meth:`.dump`
  101. * accelerated evaluation via :meth:`.__call__`
  102. :param function: the function will be traced.
  103. :param symbolic: whether to apply symbolic execution for tracing. Default: False
  104. :param capture_as_const: capture global vars or closures as const value. Default: False
  105. :param sublinear_memory_config: configuration for sublinear memory optimization.
  106. If not None, it enables sublinear memory optimization with given setting.
  107. :param profiling: whether to profile compiled trace. Default: False
  108. :param opt_level: optimization level for compiling trace.
  109. :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
  110. """
  111. def __new__(cls, *args, **kwargs):
  112. if not args:
  113. return functools.partial(cls, **kwargs)
  114. return super().__new__(cls)
  def __init__(
      self,
      function,
      symbolic=False,
      capture_as_const=False,
      sublinear_memory_config: SublinearMemoryConfig = None,
      profiling: bool = False,
      opt_level: int = None,
      symbolic_shape: bool = True,
  ):
      """Store the wrapped callable and the tracing options, then reset state.

      :param function: the callable to trace.
      :param symbolic: stored as ``_symbolic``.
      :param capture_as_const: stored as ``_capture_as_const``.
      :param sublinear_memory_config: stored as ``_sublinear_memory_config``.
      :param profiling: stored as ``_profiling``.
      :param opt_level: stored as ``_graph_opt_level``.
      :param symbolic_shape: stored as ``_symbolic_shape``.
      """
      self.__wrapped__ = function

      # user-supplied options
      self._symbolic = symbolic
      self._symbolic_shape = symbolic_shape
      self._capture_as_const = capture_as_const
      self._sublinear_memory_config = sublinear_memory_config
      self._graph_opt_level = opt_level
      self._profiling = profiling

      # state that survives ``_reset``
      self._profiler = None
      self._output_handles = set()

      self._reset()
  135. def _reset(self):
  136. self._untraced = True
  137. self._tinfo = [] # handle -> TensorInfo
  138. self._seq = []
  139. self._pc = 0
  140. self._graph = None
  141. self._need_reset_nodes = None
  142. self._lazy_eval_graph = None
  143. self._lazy_eval_tensors = {}
  144. self._lazy_eval_links = None
  145. self._active_tensors = {}
  146. self._tensor_remaps = None
  147. self._inputs_to_restore = None
  148. self._arg_bindings = None
  149. self._kwarg_bindings = None
  150. self._output_bindings = None
  151. self._output_names = None
  def _new_handle(self):
      """Append a fresh ``TensorInfo`` and return ``(handle, info)``.

      The handle is the index of the new record in ``self._tinfo``.
      """
      info = TensorInfo()
      handle = len(self._tinfo)
      self._tinfo.append(info)
      return handle, info
  def _apply_op(self, op, args):
      """Replay one op against the recorded trace and return its outputs.

      The op and each of its inputs are validated against the record at
      ``self._pc``; any divergence raises ``TraceMismatchError``. For
      external inputs without bound data, the current value is fed to the
      graph through the recorded ``data_setter``.

      :param op: the op being applied this time.
      :param args: input tensors of the op.
      :return: list of output tensors backed by the recorded varnodes.
      """
      assert not self._untraced
      # check against trace
      if self._pc >= len(self._seq):
          raise TraceMismatchError("trace should end here, but more op observed")
      recorded_op, ihandles, ohandles = self._seq[self._pc]
      recorded_is_const = isinstance(recorded_op, str) and recorded_op == "Const"
      if recorded_is_const or (op != recorded_op):
          raise TraceMismatchError("op different from last time")
      if len(ihandles) != len(args):
          raise TraceMismatchError("op input size different from last time")

      # check all inputs of the current op
      for handle, tensor in zip(ihandles, args):
          info = self._tinfo[handle]

          if not info.external:
              if tensor._mixin_handle == -1:
                  if tensor._handle not in self._tensor_remaps:
                      raise TraceMismatchError(
                          "unexpected capture: trying to use an external tensor as "
                          "input, but that input was an internal tensor last time"
                      )
                  tensor._mixin_handle = self._tensor_remaps[
                      tensor._handle
                  ]._CompiledTensorProxy__handle
              if tensor._mixin_handle != handle:
                  raise TraceMismatchError(
                      "mis-wiring: input edge to an data flow "
                      "graph node is different from last time"
                  )
              continue

          if (
              tensor._compiled_info is not None
              and not self._tinfo[tensor._mixin_handle].exported
          ):
              raise TraceMismatchError(
                  "failed to capture: input was an external tensor "
                  "last time, got an internal tensor this time"
              )
          if info.bound_data:
              if tensor._compiled_info is not None:
                  raise TraceMismatchError(
                      "const capture violated: was an external tensor "
                      "last time, got an internal tensor this time"
                  )
              # Only compare values when it is not the very same tensor.
              if tensor._handle != info.bound_data._handle and not np.array_equal(
                  tensor.numpy(), info.bound_data.numpy()
              ):
                  raise TraceMismatchError(
                      "const capture violated: got "
                      "a different tensor this time"
                  )
          else:
              if info.dtype != tensor.dtype:
                  raise TraceMismatchError(
                      "failed to capture: different dtype from last time"
                  )
              if info.device != tensor.device:
                  raise TraceMismatchError(
                      "failed to capture: different device from last time"
                  )
              info.data_setter.set_value(tensor._dev_tensor())

      self._pc += 1
      outputs = []
      for handle in ohandles:
          # generate the output tensor and attach its compiled info
          out = RawTensor(self._tinfo[handle].varnode)
          out._compiled_info = CompiledTensorProxy(handle)
          out._mixin_handle = handle
          outputs.append(out)
          self._active_tensors[handle] = TensorWeakRef(out)
      self._output_handles.update(ohandles)
      return outputs
  230. def _apply_const(self, value, dtype, device):
  231. assert not self._untraced
  232. # check against trace
  233. if self._pc >= len(self._seq):
  234. raise TraceMismatchError("trace should end here, but more op observed")
  235. record = self._seq[self._pc]
  236. op_, ihandles, ohandles = record
  237. # Const op is represented by a str
  238. assert isinstance(op_, str) and op_ == "Const"
  239. eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
  240. if not eq:
  241. raise TraceMismatchError(
  242. "const tensor violated: got a different tensor this time"
  243. )
  244. self._pc += 1
  245. (h,) = ohandles
  246. outputs = [self._tinfo[h].bound_data]
  247. return outputs
  # run in first step, record information for trace
  def _record_op(self, op, inputs, outputs):
      """Record one op application while the function is first traced.

      While tracing is being skipped, nothing is appended to the trace;
      inputs that already have a handle are only marked as ``data_read``.
      Otherwise each input gets (or reuses) a handle, each output gets a
      new handle, and ``(op, ihandles, ohandles)`` is appended to
      ``self._seq``.

      :param op: the op that was applied.
      :param inputs: input tensors of the op.
      :param outputs: output tensors of the op.
      """
      if skip_tracing:
          for x in inputs:
              h = getattr(x, "_mixin_handle", -1)
              if h >= 0:
                  # TensorInfo uses __slots__ and has no ``data`` slot; the
                  # flag meant here is ``data_read`` (as in _record_const).
                  self._tinfo[h].data_read = True
          return

      ihandles = []
      for x in inputs:
          h = getattr(x, "_mixin_handle", -1)
          if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
              # Untracked (or exported) input: register it as external.
              h, info = self._new_handle()
              name = (
                  auto_naming.get_scope() + "." + (x.c_name if x.c_name else x._name)
              )
              info.name = name
              info.external = True
              info.device = x.device
              info.dtype = x.dtype
              info.shape = x.shape
              if self._capture_as_const:
                  info.bound_data = RawTensor(
                      x.numpy(), x.dtype, x.device, False, name
                  )
          ihandles.append(h)

      ohandles = []
      for x in outputs:
          h, info = self._new_handle()
          ohandles.append(h)
          info.external = False
          x._mixin_handle = h
          x._recording = True
          x._trace_mixin_info = info
          self._active_tensors[h] = TensorWeakRef(x)
          if self._symbolic:
              self._lazy_eval_tensors[h] = TensorWeakRef(x)

      self._seq.append((op, tuple(ihandles), tuple(ohandles)))
  def _record_const(self, outputs):
      """Record the creation of a constant tensor during the first run.

      :param outputs: one-element sequence holding the constant tensor.
      """
      (const,) = outputs

      if skip_tracing:
          # Tracing is being skipped: only note that a known tensor was read.
          known = getattr(const, "_mixin_handle", -1)
          if known >= 0:
              self._tinfo[known].data_read = True
          return

      handle, info = self._new_handle()
      info.external = True
      info.is_const = True
      info.bound_data = const
      info.device = const.device
      info.dtype = const.dtype
      info.shape = const.shape

      const._mixin_handle = handle
      const._recording = True
      const._trace_mixin_info = info
      if self._symbolic:
          self._lazy_eval_tensors[handle] = TensorWeakRef(const)
      self._seq.append(("Const", (), (handle,)))
  def _set_active(self, active: bool):
      """Install or remove this trace as the module-wide ``active_trace``.

      :raises NotImplementedError: when activating while another trace is
          already active (nested traces are not supported).
      """
      global active_trace
      if not active:
          assert active_trace is self
          active_trace = None
          return
      if active_trace:
          raise NotImplementedError("sorry, not implemented: nested trace")
      active_trace = self
  317. def _init_trace(self, symbolic: bool):
  318. if symbolic:
  319. self._lazy_eval_graph = G.Graph()
  320. self._apply_graph_options(self._lazy_eval_graph)
  321. self._lazy_eval_links = ()
  322. def _take_escaped_tensors(self):
  323. escaped_tensors = tuple(
  324. filter(lambda x: x() is not None, self._active_tensors.values())
  325. )
  326. self._active_tensors.clear()
  327. return escaped_tensors
  def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
      """Compile and run the lazy graph, then materialize the live tensors.

      :param lazy_eval_graph: graph built during symbolic tracing.
      :param lazy_eval_tensors: mapping of handle -> weak reference to a
          tensor; dead references are ignored.
      :param lazy_eval_links: extra graph outputs compiled with the readers.
      """
      live_refs = [ref for ref in lazy_eval_tensors.values() if ref() is not None]
      readers = [G.OutputNode(ref()._varnode).outputs[0] for ref in live_refs]
      self._apply_graph_options(lazy_eval_graph)
      # FIXME
      opt_level = self._graph_opt_level
      lazy_eval_graph.options.graph_opt_level = 2 if opt_level is None else opt_level
      lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
      lazy_eval_graph.compile(*lazy_eval_links, *readers)
      lazy_eval_graph()
      for reader, ref in zip(readers, live_refs):
          # take the value computed by the graph and give it to the tensor
          ref()._handle = RawTensor(reader.op.get_value())._handle
          ref()._reset_varnode()
  @contextlib.contextmanager
  def _setup(self):
      """Context manager bracketing one invocation of the traced function.

      On entry tracing is switched on and, once a trace exists, the
      compiled graph is (built and) executed. On normal exit the run is
      checked for a premature end. Finalization always runs; if the body
      raised, the whole trace is reset afterwards and the exception
      propagates.
      """
      interrupted = False

      def do_enter():
          set_tracing()
          self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
          self._set_active(True)
          if self._untraced:
              self._init_trace(self._symbolic)
              return
          set_compiled()
          if self._graph is None:
              self._compile()
          self._graph.execute()

      def do_exit():
          unset_tracing()
          if not self._untraced and self._pc != len(self._seq):
              raise TraceMismatchError("premature end")
          if self._symbolic and self._untraced:
              # outputs are materialized later by the lazy evaluation
              return
          # reset output tensors
          for ref in self._active_tensors.values():
              tensor = ref()
              if tensor is None:
                  continue
              tensor._dev_tensor()
              tensor._reset_varnode()
              tensor._mixin_handle = -1
              tensor._recording = False
              tensor._trace_mixin_info = None

      def do_finalize():
          escaped = self._take_escaped_tensors()
          if self._untraced:
              for ref in escaped:
                  tensor = ref()
                  if tensor:
                      self._tinfo[tensor._mixin_handle].data_read = True
                      tensor._mixin_handle = -1
                      tensor._recording = False
              if self._inputs_to_restore:
                  for tensor in self._inputs_to_restore:
                      tensor._mixin_handle = -1
                      tensor._recording = False
              if self._symbolic and (
                  self._lazy_eval_tensors or self._lazy_eval_links
              ):
                  # evaluate the tensors that were deferred while tracing
                  self._lazy_eval(
                      self._lazy_eval_graph,
                      self._lazy_eval_tensors,
                      self._lazy_eval_links,
                  )
                  self._lazy_eval_graph = None
                  self._lazy_eval_tensors = None
                  self._lazy_eval_links = None
              self._untraced = False
          elif self._pc == len(self._seq):
              # compiled tensors leaked out of the call
              for ref in escaped:
                  try:
                      assign_raw_tensor(ref(), RawTensor(ref()._dev_tensor()))
                  except RuntimeError:
                      # TraceMismatchError thrown in do_exit
                      pass
              self._graph.wait()
              self._reset_exec_env()

          # reset status
          self._pc = 0
          self._tensor_remaps = None
          self._set_active(False)
          set_symbolic_shape(self._save_symbolic_shape)
          unset_compiled()
          unset_tracing()

      try:
          do_enter()
          yield
          do_exit()
      except:
          # Deliberately bare: every kind of interruption must trigger the
          # reset below, and the exception is re-raised untouched.
          interrupted = True
          raise
      finally:
          do_finalize()
          if interrupted:
              self._reset()
  428. def _begin_excluded_region(self):
  429. if self._capture_as_const:
  430. raise RuntimeError(
  431. "exclude_from_trace cannot be used with capture_as_const"
  432. )
  433. if self._untraced:
  434. # conditionally reading a compiled tensor in excluded region
  435. # is permitted, so we have to assume every tensor might be read
  436. for x in self._active_tensors.values():
  437. if x():
  438. info = self._tinfo[x()._mixin_handle]
  439. info.exported = True
  440. info.data_read = True
  441. else:
  442. for x in self._active_tensors.values():
  443. if x():
  444. x()._dev_tensor()
  445. def _apply_graph_options(self, graph):
  446. graph.options.no_force_inplace = True
  447. graph.options.seq_opt.enable_seq_comp_node_opt = False
  448. # graph opt level
  449. # if self._graph_opt_level is not None:
  450. # graph.options.graph_opt_level = self._graph_opt_level
  451. # FIXME
  452. graph.options.graph_opt_level = 0
  453. # sublinear
  454. if self._sublinear_memory_config is not None:
  455. graph.options.enable_sublinear_memory_opt = True
  456. sublinear_config = graph.options.sublinear_mem_config
  457. sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
  458. sublinear_config.genetic_nr_iter = (
  459. self._sublinear_memory_config.genetic_nr_iter
  460. )
  461. sublinear_config.genetic_pool_size = (
  462. self._sublinear_memory_config.genetic_pool_size
  463. )
  464. sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
  465. sublinear_config.num_worker = self._sublinear_memory_config.num_worker
  466. # profile
  467. if self._profiling:
  468. self._profiler = GraphProfiler(graph)
  469. if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
  470. graph.options.var_sanity_check_first_run = False
  def _compile(self):
      """Build and compile the execution graph from the recorded trace.

      Walks ``self._seq``, creating a varnode for every handle: input
      nodes for external tensors without bound data, constants for bound
      data, and op outputs for everything else. Output nodes are added for
      tensors flagged ``data_read`` / ``value_read`` / ``shape_read``. The
      graph is stored in ``self._graph`` and the nodes needing a reset
      after each run in ``self._need_reset_nodes``.
      """
      graph = self._graph = G.Graph()
      graph.options.async_exec_level = 0b100
      self._apply_graph_options(graph)
      need_reset_nodes = self._need_reset_nodes = []
      # links enforce ordering of I/O nodes
      in_out_links = ()
      io_links = ()
      readers = []

      def add_reader(opnode):
          # Register an output node and make later I/O nodes depend on it.
          nonlocal in_out_links
          need_reset_nodes.append(opnode)
          readers.append(opnode.outputs[0])
          in_out_links = opnode.outputs

      if self._capture_as_const:
          bound_handles = itertools.chain(
              self._arg_bindings, self._kwarg_bindings.values()
          )
          for handle in bound_handles:
              info = self._tinfo[handle]
              opnode = info.data_setter = G.InputNode(
                  device=info.device,
                  dtype=info.dtype,
                  shape=info.shape or (1,),
                  graph=graph,
                  use_static_shape=_input_node_use_static_shape(),
              )
              need_reset_nodes.append(opnode)
              info.varnode = opnode.outputs[0]
              in_out_links += opnode.outputs[1:]

      for op, ihandles, ohandles in self._seq:
          if isinstance(op, str) and op == "Const":
              assert len(ihandles) == 0
              (handle,) = ohandles
              info = self._tinfo[handle]
              # ``varnode`` is an unset slot until a node has been created
              if not hasattr(info, "varnode"):
                  assert info.external
                  assert info.bound_data
                  info.varnode = graph.make_const(
                      info.bound_data.numpy(),
                      info.bound_data.dtype,
                      info.bound_data.device,
                  )
              continue

          require_links = type(op) in _io_op_types
          ivars = []
          for position, handle in enumerate(ihandles):
              info = self._tinfo[handle]
              if not hasattr(info, "varnode"):
                  assert info.external
                  if not info.bound_data:
                      opnode = info.data_setter = G.InputNode(
                          *in_out_links,
                          device=info.device,
                          dtype=info.dtype,
                          shape=info.shape or (1,),
                          graph=graph,
                          use_static_shape=_input_node_use_static_shape(),
                      )
                      need_reset_nodes.append(opnode)
                      info.varnode, *in_out_links = opnode.outputs
                  elif hasattr(info, "is_const") and info.is_const:
                      info.varnode = graph.make_const(
                          info.bound_data.numpy(),
                          info.bound_data.dtype,
                          info.bound_data.device,
                      )
                  else:
                      info.varnode = graph.make_const(info.bound_data._dev_tensor())
              if require_links and position == 0 and len(io_links) > 0:
                  opnode = G.VirtualDepNode(
                      [info.varnode, *io_links], str(io_links[0].device)
                  )
                  info.varnode = opnode.outputs[0]
                  io_links = (info.varnode,)
              ivars.append(info.varnode)

          apply_fn = (
              G.apply_backward_varnode
              if isinstance(op, BackwardGraph)
              else G.apply_normal_varnode
          )
          ovars = apply_fn(op, *ivars)

          if require_links and len(ovars) > 0:
              io_links = (ovars[0],)
          assert len(ovars) == len(ohandles)
          for handle, var in zip(ohandles, ovars):
              info = self._tinfo[handle]
              info.varnode = var

              if info.data_read:
                  # Shape can be obtained from data so doesn't need its own
                  # output node. On the other hand, value is read separately
                  # to leverage eager h2d copy
                  info.shape_read = False
                  opnode = info.data_reader = G.OutputNode(var, *in_out_links)
                  add_reader(opnode)
              if info.value_read:
                  opnode = info.value_reader = G.ValueOutputNode(var, *in_out_links)
                  add_reader(opnode)
              if info.shape_read:
                  opnode = info.shape_reader = G.AttrOutputNode(var, *in_out_links)
                  add_reader(opnode)

      # FIXME
      opt_level = self._graph_opt_level
      graph.options.graph_opt_level = 2 if opt_level is None else opt_level
      graph._set_priority_to_id([*readers, *in_out_links, *io_links])
      graph.compile(*readers, *in_out_links, *io_links)
  579. def _reset_exec_env(self):
  580. for opnode in self._need_reset_nodes:
  581. opnode.reset()
  def __call__(self, *args, **kwargs):
      """Run the wrapped function under tracing and return its result.

      When called while another trace is already in progress, the wrapped
      function is simply invoked. Otherwise the call happens inside
      ``self._setup()``; with ``capture_as_const`` the inputs and outputs
      are additionally processed for later dumping.
      """
      if is_tracing():
          return self.__wrapped__(*args, **kwargs)

      with self._setup():
          if self._capture_as_const:
              self._process_inputs(*args, **kwargs)
          outputs = self.__wrapped__(*args, **kwargs)
          if self._capture_as_const:
              self._process_outputs(outputs)

          # outputs could be None
          if outputs is None:
              return outputs

          if isinstance(outputs, collections.abc.Mapping):
              _, flat_outputs = zip(*sorted(outputs.items()))
          elif isinstance(outputs, collections.abc.Sequence):
              flat_outputs = outputs
          else:
              flat_outputs = (outputs,)

          for out in flat_outputs:
              # if outputs are copied, then use the newest info in trace data structure
              if not out._copied:
                  continue
              self._active_tensors[out._mixin_handle] = TensorWeakRef(out)
              if self._untraced and self._symbolic:
                  self._lazy_eval_tensors[out._mixin_handle] = TensorWeakRef(out)

          return outputs
  def dump(
      self,
      file,
      *,
      arg_names=None,
      output_names=None,
      append=False,
      keep_var_name: int = 1,
      keep_opr_name: bool = False,
      keep_param_name: bool = False,
      keep_opr_priority: bool = False,
      strip_info_file=None,
      append_json=False,
      optimize_for_inference=True,
      **kwargs
  ):
      r"""
      Serializes trace to file system.

      :param file: output file, could be file object or filename.
      :param arg_names: names of the input tensors in the traced function.
      :param output_names: names of the output tensors in the traced function,
          use the default name if not specified.
      :param append: whether output is appended to ``file``.
          Only works when ``file`` is str.
      :param keep_var_name: level for keeping variable names:

          * 0: none of the names are kept
          * 1: (default) keep names of output vars
          * 2: keep names of all (output and internal) vars
      :param keep_opr_name: whether to keep operator names.
      :param keep_param_name: whether to keep param names, so param values can be
          easily manipulated after loading model
      :param keep_opr_priority: whether to keep priority setting for operators
      :param strip_info_file: a string for path or a file handler. If it is not None,
          then the dump information for code strip would be written to ``strip_info_file``
      :param append_json: only checked when ``strip_info_file`` is not None. If set
          true, the information for code strip will be appended to strip_info_file.
          If set false, strip_info_file will be rewritten
      :param optimize_for_inference: enable optimizations,
          will skip all optimize options if this is False. Default: True
      :return: the dump information returned by ``G.dump_graph``.
      :raises ValueError: if the trace was not created with
          ``capture_as_const=True``, or if the number of ``arg_names`` /
          ``output_names`` does not match the recorded bindings.
      :raises RuntimeError: if the function has not been traced yet.
      :raises TypeError: if ``output_names`` is given although the traced
          function already returned named (dict) outputs.

      :Keyword Arguments:

          * enable_io16xc32 --
              whether to use float16 for I/O between oprs and use
              float32 as internal computation precision. Note the output var would be
              changed to float16.
          * enable_ioc16 --
              whether to use float16 for both I/O and computation
              precision.

          * enable_hwcd4 --
              whether to use NHWCD4 data layout. This is faster on some
              OpenCL backend.
          * enable_nchw88 --
              whether to use NCHW88 data layout, currently
              used in X86 AVX backend.
          * enable_nchw44 --
              whether to use NCHW44 data layout, currently
              used in arm backend.
          * enable_nchw44_dot --
              whether to use NCHW44_dot data layout, currently
              used in armv8.2+dotprod backend.
          * enable_nchw4 --
              whether to use NCHW4 data layout, currently
              used in nvidia backend(based on cudnn).
          * enable_nchw32 --
              whether to use NCHW32 data layout, currently
              used in nvidia backend with tensorcore(based on cudnn).
          * enable_chwn4 --
              whether to use CHWN4 data layout, currently
              used in nvidia backend with tensorcore.

          * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
              into one opr.
          * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
              input for inference on nvidia backend(this optimization pass will
              result in mismatch of the precision of output of training and
              inference)
      """
      if not self._capture_as_const:
          raise ValueError(
              "you must specify capture_as_const=True at __init__ to use dump"
          )
      if self._untraced:
          raise RuntimeError("should run at least once before calling dump")
      if self._output_names and output_names:
          raise TypeError(
              "cannot specify output_names when output is already in dict format"
          )

      # A single name may be passed bare; wrap it so it can be indexed.
      if output_names and not isinstance(output_names, collections.abc.Sequence):
          output_names = (output_names,)
      if output_names and len(output_names) != len(self._output_bindings):
          raise ValueError(
              "wrong number of output_names, should be {} values".format(
                  len(self._output_bindings)
              )
          )
      if arg_names is None:
          arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
      if arg_names and not isinstance(arg_names, collections.abc.Sequence):
          arg_names = (arg_names,)
      if arg_names and len(arg_names) != len(self._arg_bindings):
          raise ValueError(
              "wrong number of arg_names, should be {} values".format(
                  len(self._arg_bindings)
              )
          )
      output_names = output_names or self._output_names

      dumped_device = as_device("xpux")

      # handle -> varnode in the graph being dumped
      h2v = {}
      graph = G.Graph()
      # apply graph_opt_level in dump
      if self._graph_opt_level is not None:
          graph.options.graph_opt_level = self._graph_opt_level

      for i, h in enumerate(self._arg_bindings):
          info = self._tinfo[h]
          h2v[h] = graph.make_h2d(
              dtype=info.dtype,
              device=dumped_device,
              shape=info.shape or (1,),
              name=arg_names[i] if arg_names else None,
          )
      for k, h in self._kwarg_bindings.items():
          info = self._tinfo[h]
          h2v[h] = graph.make_h2d(
              dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
          )

      for op, ihandles, ohandles in self._seq:
          if isinstance(op, str) and op == "Const":
              assert len(ihandles) == 0
              (h,) = ohandles
              info = self._tinfo[h]
              if h not in h2v:
                  assert info.external
                  assert info.bound_data
                  h2v[h] = graph.make_const(
                      info.bound_data.numpy(),
                      dtype=info.dtype,
                      device=info.device,
                      name=info.name,
                  )
              continue

          ivars = []
          for h in ihandles:
              info = self._tinfo[h]
              if h not in h2v:
                  # Any input that is neither an argument nor an earlier
                  # output must be a captured constant.
                  assert info.external
                  assert info.bound_data
                  h2v[h] = graph.make_const(
                      info.bound_data.numpy(),
                      dtype=info.dtype,
                      device=dumped_device,
                      name=info.name,
                  )
              ivars.append(h2v[h])
          ovars = G.apply_normal_varnode(op, *ivars)

          auto_naming.record_opnode(ovars[0].op)

          assert len(ovars) == len(ohandles)
          h2v.update(zip(ohandles, ovars))

          for i in ohandles:
              name = auto_naming.get_var_name(i)
              if name is not None:
                  h2v[i].name = name

      auto_naming.remove_duplicate_names()

      dest_vars = []
      for i, h in enumerate(self._output_bindings):
          v = h2v[h]
          if output_names:
              v.name = output_names[i]
          dest_vars.append(v)

      if optimize_for_inference:
          dest_vars = G.optimize_for_inference(dest_vars, **kwargs)

      dump_content, dump_info = G.dump_graph(
          dest_vars,
          keep_var_name=keep_var_name,
          keep_opr_name=keep_opr_name,
          keep_param_name=keep_param_name,
          keep_opr_priority=keep_opr_priority,
          strip_info_file=strip_info_file,
          append_json=append_json,
      )

      if isinstance(file, str):
          # The path is opened only after serialization succeeded, and the
          # handle is closed here because this method created it. The mode
          # test is kept as an equality check so non-bool ``append`` values
          # keep their previous meaning.
          permission = "wb" if append == False else "ab"
          with open(file, permission) as fout:
              fout.write(dump_content)
      else:
          # Caller-owned file object: write, but leave it open.
          file.write(dump_content)
      return dump_info
  def _process_inputs(self, *args, **kwargs):
      """
      Bind the tensor arguments of the current call to trace handles.

      While the trace is still untraced, every tensor argument receives a new
      handle whose info stores its device, dtype and shape, and the tensor is
      remembered in ``_inputs_to_restore``. On later calls the arguments are
      checked against the recorded infos, their device tensors are pushed
      through the recorded ``data_setter`` and each tensor is remapped to a
      :class:`CompiledTensorProxy`.

      :param args: positional arguments; every one must be a tensor.
      :param kwargs: keyword arguments; only tensor values are bound.
      :raises TypeError: if a positional argument is not a tensor, or a
          dtype/device differs from the recorded one.
      :raises TraceMismatchError: if the number of positional arguments or
          the set of tensor keyword arguments differs from the recorded one.
      """
      if self._untraced:
          self._inputs_to_restore = []

          def bind_new_handle(tensor):
              # ``None`` is passed through without allocating a handle.
              if tensor is None:
                  return
              handle, tinfo = self._new_handle()
              tinfo.external = False
              tinfo.device = tensor.device
              tinfo.dtype = tensor.dtype
              tinfo.shape = tensor.numpy().shape
              tensor._mixin_handle = handle
              tensor._recording = True
              tensor._trace_mixin_info = tinfo
              self._inputs_to_restore.append(tensor)
              return handle

          self._arg_bindings = []
          for index, tensor in enumerate(args):
              if not isinstance(tensor, RawTensor):
                  raise TypeError(
                      "positional arguments should all be tensor "
                      "but args[%d] cannot be recognized as one" % index
                  )
              self._arg_bindings.append(bind_new_handle(tensor))

          self._kwarg_bindings = {}
          for key, value in kwargs.items():
              if isinstance(value, RawTensor):
                  self._kwarg_bindings[key] = bind_new_handle(value)
          return

      # From here on the trace has already been recorded: validate and rebind.
      if len(args) != len(self._arg_bindings):
          raise TraceMismatchError("positional argument length mismatch")

      self._tensor_remaps = {}

      for index, (handle, tensor) in enumerate(zip(self._arg_bindings, args)):
          if not isinstance(tensor, RawTensor):
              raise TypeError(
                  "positional arguments should all be tensor "
                  "but args[%d] cannot be recognized as one" % index
              )
          tinfo = self._tinfo[handle]
          if tensor.dtype != tinfo.dtype:
              raise TypeError("args[%d].dtype different from last time" % index)
          if tensor.device != tinfo.device:
              raise TypeError("args[%d].device different from last time" % index)
          tinfo.data_setter.set_value(tensor._dev_tensor())
          self._tensor_remaps[tensor._handle] = CompiledTensorProxy(handle)

      # Only tensor-valued keyword arguments take part in the binding.
      kwargs_tensors = {
          key: value for key, value in kwargs.items() if isinstance(value, RawTensor)
      }
      if set(kwargs_tensors) != set(self._kwarg_bindings):
          too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
          too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
          if too_many:
              raise TraceMismatchError(
                  "keyword arguments found to be tensor this time "
                  "but were non-tensor previously: %s" % " ".join(too_many)
              )
          if too_few:
              raise TraceMismatchError(
                  "keyword arguments found to be non-tensor this time "
                  "but were tensor previously: %s" % " ".join(too_few)
              )

      for key, handle in self._kwarg_bindings.items():
          tensor = kwargs_tensors[key]
          tinfo = self._tinfo[handle]
          if tensor.dtype != tinfo.dtype:
              raise TypeError("kwargs[%s].dtype different from last time" % key)
          if tensor.device != tinfo.device:
              raise TypeError("kwargs[%s].device different from last time" % key)
          tinfo.data_setter.set_value(tensor._dev_tensor())
          self._tensor_remaps[tensor._handle] = CompiledTensorProxy(handle)
  858. def _process_outputs(self, outputs):
  859. output_names = None
  860. if isinstance(outputs, collections.abc.Mapping):
  861. output_names, outputs = zip(*sorted(outputs.items()))
  862. elif not isinstance(outputs, collections.abc.Sequence):
  863. outputs = (outputs,)
  864. if not self._untraced:
  865. if output_names != self._output_names:
  866. too_many = set(output_names) - set(self._output_names)
  867. too_few = set(self._output_names) - set(output_names)
  868. if too_many:
  869. raise TraceMismatchError(
  870. "output has more keys than last time: %s" % " ".join(too_many)
  871. )
  872. if too_few:
  873. raise TraceMismatchError(
  874. "output has less keys than last time: %s" % " ".join(too_few)
  875. )
  876. if len(outputs) != len(self._output_bindings):
  877. raise TraceMismatchError("output size differs from last time")
  878. else:
  879. self._output_names = output_names
  880. self._output_bindings = []
  881. for i, x in enumerate(outputs):
  882. if not isinstance(x, RawTensor):
  883. raise TypeError("every item of return value should be tensor")
  884. if self._untraced:
  885. h = x._mixin_handle
  886. if h < 0:
  887. raise RuntimeError("output is not computed from inputs")
  888. self._output_bindings.append(h)
  889. else:
  890. h = x._mixin_handle
  891. if h not in self._output_handles:
  892. raise RuntimeError("output is not computed from inputs")
  893. if h != self._output_bindings[i]:
  894. raise TraceMismatchError(
  895. "retval[%s] is a different tensor than last time"
  896. % (output_names and output_names[i] or i)
  897. )
  def get_profile(self):
      """
      Get profiling result for compiled trace.

      :return: a json compatible object, decoded from the profiler's output.
      :raises RuntimeError: if no profiler is attached to this trace.
      """
      profiler = self._profiler
      if profiler:
          return json.loads(profiler.get())
      raise RuntimeError("trace is not set with profiling=True")
  906. def __del__(self):
  907. for x in self._tinfo:
  908. if getattr(x, "bound_data", None):
  909. x.bound_data = None
  def trace(self, *args, **kwargs):
      """
      Unsupported entry point; calling it always fails.

      :param args: ignored.
      :param kwargs: ignored.
      :raises NotImplementedError: always, pointing the caller to ``__call__``.
      """
      raise NotImplementedError(
          "trace is deemed unbeneficial with the new "
          "tracing mechanism. You should always use __call__."
      )
  class CompiledTensorProxy:
      """
      Duck-typed RawTensor

      Wraps a handle of the active trace and exposes ``dtype``, ``device``,
      ``shape``, ``numpy()`` and ``_dev_tensor()`` by reading from the tensor
      info registered for that handle. Values fetched from a reader are cached
      on the proxy and dropped from the reader again when the proxy is
      finalized.
      """

      def __init__(self, handle):
          """
          :param handle: index into ``active_trace._tinfo``.
          """
          self.__handle = handle
          self._isscalar = False
          # Set up the caches before the lookup below, so every proxy that
          # owns an info also owns its cache slots.
          self.__shape = None
          self.__data = None
          self.__value = None
          self.__info = active_trace._tinfo[handle]

      @property
      def dtype(self):
          """dtype of the varnode recorded for this handle."""
          return self.__info.varnode.dtype

      @property
      def device(self):
          """device of the varnode recorded for this handle."""
          return self.__info.varnode.device

      @property
      def shape(self):
          """
          Shape of the tensor: ``()`` for scalars, otherwise read from the
          shape reader or the device tensor. ``None`` if neither was recorded
          as read.
          """
          if self._isscalar:
              return ()
          if self.__shape is None:
              if self.__info.shape_read:
                  self.__shape = self.__info.shape_reader.get_value().shape
              elif self.__info.data_read:
                  self.__shape = self._dev_tensor().shape
              else:
                  # c++ will throw TraceReadError
                  return None
          return self.__shape

      def numpy(self):
          """
          Host value of the tensor, read from the value reader or converted
          from the device tensor. ``None`` if neither was recorded as read.
          """
          if self.__value is None:
              if self.__info.value_read:
                  self.__value = self.__info.value_reader.get_value()
              elif self.__info.data_read:
                  self.__value = self._dev_tensor().numpy()
              else:
                  # c++ will throw TraceReadError
                  return None
          # c++ side will handle scalar case
          return self.__value

      def _dev_tensor(self):
          """
          Device tensor from the data reader, or ``None`` if the data was not
          recorded as read.
          """
          if self.__data is None:
              if not self.__info.data_read:
                  # c++ will throw TraceReadError
                  return None
              self.__data = self.__info.data_reader.get_value()
          return self.__data

      def __del__(self):
          """Drop every reader value this proxy has fetched."""
          try:
              info = self.__info
          except AttributeError:
              # ``__init__`` failed before an info was bound (e.g. invalid
              # handle); nothing was fetched, so nothing has to be dropped.
              return
          if info.shape_read and self.__shape is not None:
              info.shape_reader.drop_value()
          if info.value_read and self.__value is not None:
              info.value_reader.drop_value()
          if info.data_read and self.__data is not None:
              info.data_reader.drop_value()
  def assign_raw_tensor(lhs, rhs):
      """
      Re-initialise ``lhs`` in place from ``rhs``.

      :param lhs: object whose ``__init__`` is invoked again.
      :param rhs: the single argument passed to ``lhs.__init__``.
      """
      reinitialize = lhs.__init__
      reinitialize(rhs)
  def apply_symbolic_mode(op: OpDef, *args: RawTensor):
      """
      Apply ``op`` on the lazy-evaluation graph of the active trace.

      Inputs that already carry a ``_varnode`` use it directly; other inputs
      are fed into the graph through a new ``G.InputNode``. For op types in
      ``_io_op_types`` the op is chained to the previously recorded links via
      a ``G.VirtualDepNode`` and its first output becomes the new link.

      :param op: the operator to apply.
      :param args: input tensors.
      :return: list of ``RawTensor`` wrapping the output varnodes.
      """
      lazy_graph = active_trace._lazy_eval_graph

      def as_varnode(tensor):
          # Reuse the varnode attached to the tensor when there is one.
          attached = getattr(tensor, "_varnode", None)
          if attached:
              return attached
          feeder = G.InputNode(
              device=tensor.device,
              dtype=tensor.dtype,
              # an empty shape is replaced by (1,)
              shape=tensor.numpy().shape or (1,),
              graph=lazy_graph,
              use_static_shape=True,
          )
          fed_var = feeder.outputs[0]
          feeder.set_value(tensor._dev_tensor())
          return fed_var

      input_vars = [as_varnode(tensor) for tensor in args]

      needs_link = type(op) in _io_op_types

      if needs_link and active_trace._lazy_eval_links:
          assert len(input_vars) > 0, "op should has at least one input"
          links = active_trace._lazy_eval_links
          dep_node = G.VirtualDepNode(
              [input_vars[0], *links], str(links[0].device),
          )
          input_vars[0] = dep_node.outputs[0]
          active_trace._lazy_eval_links = (input_vars[0],)

      apply_varnode = (
          G.apply_backward_varnode
          if isinstance(op, BackwardGraph)
          else G.apply_normal_varnode
      )
      output_vars = apply_varnode(op, *input_vars)
      outputs = [RawTensor(var) for var in output_vars]

      if needs_link:
          # The first output becomes the link the next io op depends on.
          active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)

      return outputs
  def apply_const_symbolic_mode(value, dtype, device, name):
      """
      Create a constant on the lazy-evaluation graph of the active trace.

      :param value: constant value; a 0-dim value marks the result as scalar.
      :param dtype: dtype passed to ``make_const``.
      :param device: device passed to ``make_const``.
      :param name: name passed to ``make_const``.
      :return: 1-tuple holding the ``RawTensor`` that wraps the constant.
      """
      # don't need to unset tracing
      # because varnode construction will ignore tracing flag
      const_var = active_trace._lazy_eval_graph.make_const(
          value, dtype=dtype, device=device, name=name
      )
      tensor = RawTensor(const_var)
      is_scalar = np.array(value).ndim == 0
      if is_scalar:
          setscalar(tensor)
      return (tensor,)
  def apply_compiled_mode(op: OpDef, *args: RawTensor):
      """
      Apply ``op`` while a compiled trace is being replayed.

      When ``skip_tracing`` is set, every ``CompiledTensorProxy`` input is
      converted to a ``RawTensor`` built from its device tensor and the op is
      applied with tracing switched off; otherwise the call is forwarded to
      ``active_trace._apply_op``.

      :param op: the operator to apply.
      :param args: input tensors or proxies.
      :return: whatever ``apply`` / ``active_trace._apply_op`` returns.
      """
      if not skip_tracing:
          return active_trace._apply_op(op, args)

      args = [
          RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
          for x in args
      ]
      unset_tracing()
      try:
          return apply(op, *args)
      finally:
          # Restore the tracing flag even if the op raises, so later ops are
          # not silently executed outside the trace.
          set_tracing()
  def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
      """
      Create a constant while a compiled trace is being replayed.

      When ``skip_tracing`` is set, the constant is built directly as a
      ``RawTensor`` with tracing switched off; otherwise the call is forwarded
      to ``active_trace._apply_const``.

      :param value: constant value.
      :param dtype: dtype of the constant.
      :param device: device of the constant.
      :param is_const: not used by this function.
      :param no_cache: not used by this function.
      :param name: name given to the directly built ``RawTensor``.
      :return: the new ``RawTensor`` on the skip path, otherwise the result of
          ``active_trace._apply_const``.
      """
      if not skip_tracing:
          return active_trace._apply_const(value, dtype, device)

      # A comprehension over ``args`` used to sit here; this function has no
      # ``args``, so the skip path always failed with UnboundLocalError.
      # NOTE(review): this path returns a bare tensor while the other const
      # helpers in this file return a sequence; confirm with the callers.
      unset_tracing()
      try:
          return RawTensor(value, dtype, device, False, name)
      finally:
          # Restore the tracing flag even if construction raises.
          set_tracing()
  def apply_with_tracing(op: OpDef, *args: RawTensor):
      """
      Apply ``op`` while a trace is being recorded and record the call.

      If ``op`` has a ``scope`` attribute it is set to the current
      auto-naming scope. In symbolic mode the op is applied through
      ``apply_symbolic_mode``; otherwise it is applied with tracing switched
      off. The op, its inputs and its outputs are then recorded on the active
      trace.

      :param op: the operator to apply.
      :param args: input tensors.
      :return: list of output tensors.
      """
      if hasattr(op, "scope"):
          op.scope = auto_naming.get_scope()
      if active_trace._symbolic:
          outputs = apply_symbolic_mode(op, *args)
      else:
          unset_tracing()
          try:
              outputs = apply(op, *args)
          finally:
              # Restore the tracing flag even if the op raises, so ops that
              # follow a handled error are still recorded.
              set_tracing()

      active_trace._record_op(op, args, outputs)
      return list(outputs)
  def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
      """
      Create a constant while a trace is being recorded and record it.

      In symbolic mode the constant is created through
      ``apply_const_symbolic_mode``; otherwise a ``RawTensor`` is built with
      tracing switched off. The result is recorded on the active trace.

      :param value: constant value.
      :param dtype: dtype of the constant.
      :param device: device of the constant.
      :param is_const: not used by this function.
      :param no_cache: not used by this function.
      :param name: name of the constant.
      :return: list holding the created tensor.
      """
      if active_trace._symbolic:
          outputs = apply_const_symbolic_mode(value, dtype, device, name)
      else:
          unset_tracing()
          try:
              outputs = (RawTensor(value, dtype, device, False, name),)
          finally:
              # Restore the tracing flag even if construction raises.
              set_tracing()
      active_trace._record_const(outputs)
      return list(outputs)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。