
tracing.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import contextlib
import functools
import itertools
import json
import os
import pickle
from typing import Any

import numpy as np

from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    TensorWeakRef,
    apply,
    set_tracing,
    skip_tracing,
    unset_tracing,
)
from ..core._imperative_rt.ops import (
    AssertEqual,
    CollectiveComm,
    ExternOpr,
    RemoteRecv,
    RemoteSend,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils.naming import AutoNaming
from ..utils.profiler import is_profiling
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig


def _input_node_use_static_shape():
    return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None


class TraceMismatchError(RuntimeError):
    pass


active_trace = None


def is_tracing():
    if active_trace is None:
        return False
    else:
        return not skip_tracing


@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing or (active_trace is None):
        yield
        return
    try:
        skip_tracing = True
        unset_tracing()
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
        set_tracing()
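
# A minimal usage sketch (hedged; assumes only the public ``megengine`` API):
# ops executed inside ``exclude_from_trace`` run eagerly on every call instead
# of being recorded into the trace, which permits data-dependent control flow
# in a traced function.
#
#     import megengine as mge
#     from megengine.jit import trace, exclude_from_trace
#
#     @trace
#     def f(x):
#         x = x + 1
#         with exclude_from_trace():
#             # evaluated eagerly on each call, never baked into the trace
#             if int(x.numpy().sum()) > 0:
#                 x = x + 1
#         return x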


class TensorInfo:
    __slots__ = (
        # collected attributes
        "name",
        "external",
        "data_read",
        "shape_read",
        "value_read",
        "exported",
        "device",
        "dtype",
        "shape",
        "is_const",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.name = None
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None
        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None


_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}


class trace:
    """Wraps a callable and provides:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    Args:
        function: the function to be traced.
        symbolic: whether to apply symbolic execution for tracing. Default: False
        capture_as_const: capture global vars or closures as const value. Default: False
        record_only: if True, the function will not run even when called. Default: False
        sublinear_memory_config: configuration for sublinear memory optimization.
            If not None, it enables sublinear memory optimization with the given setting.
        profiling: whether to profile compiled trace. Default: False
        opt_level: optimization level for compiling trace. Default: 2
        graph_opt_config: configuration for graph optimization. Default: None
        symbolic_shape: whether to use symbolic shape for tracing. Default: True
    """

    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        function,
        symbolic=False,
        capture_as_const=False,
        record_only=False,
        sublinear_memory_config: SublinearMemoryConfig = None,
        dtr_config: DTRConfig = None,
        profiling: bool = False,
        opt_level: int = 2,
        graph_opt_config: GraphOptimizationConfig = None,
        symbolic_shape: bool = True,
    ):
        self.__wrapped__ = function
        self._symbolic = symbolic or record_only
        self._capture_as_const = capture_as_const or record_only
        self._record_only = record_only
        self._sublinear_memory_config = sublinear_memory_config
        self._dtr_config = dtr_config
        self._profiling = profiling
        self._profiler = None
        self._profiler2 = None
        self._graph_opt_level = opt_level
        self._graph_opt_config = graph_opt_config
        self._symbolic_shape = symbolic_shape
        self._output_handles = set()

        self._reset()

    def _reset(self):
        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = set()
        self._lazy_eval_links = None
        self._active_tensors = set()
        self._tensor_remaps = None
        self._inputs_to_restore = None
        self._arg_bindings = None
        self._kwarg_bindings = None
        self._output_bindings = None
        self._output_names = None

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if (isinstance(op_, str) and op_ == "Const") or (op != op_):
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        # check all inputs of current op
        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x._compiled_info is not None
                    and not self._tinfo[x._mixin_handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x._compiled_info is not None:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        if not np.array_equal(x.numpy(), info.bound_data.numpy()):
                            raise TraceMismatchError(
                                "const capture violated: got "
                                "a different tensor this time"
                            )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x._mixin_handle == -1:
                    if x._handle not in self._tensor_remaps:
                        raise TraceMismatchError(
                            "unexpected capture: trying to use an external tensor as "
                            "input, but that input was an internal tensor last time"
                        )
                    else:
                        x._mixin_handle = self._tensor_remaps[
                            x._handle
                        ]._CompiledTensorProxy__handle
                if x._mixin_handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to a data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = []
        for h in ohandles:
            info = self._tinfo[h]
            # generate output tensor and create compiled info
            y = RawTensor(info.varnode)
            y._compiled_info = CompiledTensorProxy(h)
            y._mixin_handle = h
            outputs += [y]
            self._active_tensors.add(TensorWeakRef(y))
        self._output_handles.update(ohandles)
        return outputs

    def _apply_const(self, value, dtype, device):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        # Const op is represented by a str
        assert isinstance(op_, str) and op_ == "Const"

        expected = self._tinfo[ohandles[0]].bound_data.numpy()
        shape = value.shape
        if shape != expected.shape or dtype != expected.dtype:
            eq = False
        elif shape == ():
            eq = expected.item() == value.item()
        elif shape == (1,):
            eq = expected[0] == value[0]
        else:
            eq = np.all(value == expected)
        if not eq:
            raise TraceMismatchError(
                "const tensor violated: got a different tensor this time"
            )

        self._pc += 1
        (h,) = ohandles
        outputs = [self._tinfo[h].bound_data]
        return outputs

    # runs during the first call; records information for the trace
    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_mixin_handle", -1)
                if h >= 0:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_mixin_handle", -1)
            if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                name = AutoNaming.gen_name(x)
                info.name = name
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.shape
                if self._capture_as_const:
                    info.bound_data = RawTensor(
                        x.numpy(), x.dtype, x.device, False, name
                    )
            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            x._mixin_handle = h
            x._recording = True
            x._trace_mixin_info = info
            self._active_tensors.add(TensorWeakRef(x))
            if self._symbolic:
                self._lazy_eval_tensors.add(TensorWeakRef(x))

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))

    def _record_const(self, outputs):
        if skip_tracing:
            (x,) = outputs
            h = getattr(x, "_mixin_handle", -1)
            if h >= 0:
                self._tinfo[h].data_read = True
            return

        (x,) = outputs
        h, info = self._new_handle()
        ohandles = [h]
        info.external = True
        info.device = x.device
        info.dtype = x.dtype
        info.shape = x.shape
        info.bound_data = x
        info.is_const = True
        x._mixin_handle = h
        x._recording = True
        x._trace_mixin_info = info
        if self._symbolic:
            self._lazy_eval_tensors.add(TensorWeakRef(x))
        self._seq.append(("Const", tuple(), tuple(ohandles)))

    def _set_active(self, active: bool):
        global active_trace
        if active:
            if active_trace:
                raise NotImplementedError("sorry, not implemented: nested trace")
            active_trace = self
        else:
            assert active_trace is self
            active_trace = None

    def _init_trace(self, symbolic: bool):
        if symbolic:
            self._lazy_eval_graph = G.Graph()
            self._apply_graph_options(self._lazy_eval_graph)
            self._lazy_eval_links = ()

    def _take_escaped_tensors(self):
        escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
        self._active_tensors.clear()
        return escaped_tensors

    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
        lazy_eval_tensors = [x() for x in lazy_eval_tensors]
        lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None]
        readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors]
        self._apply_graph_options(lazy_eval_graph)
        lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
        lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
        lazy_eval_graph.compile(*lazy_eval_links, *readers)
        self._execute_graph(lazy_eval_graph)
        lazy_eval_graph.wait()
        for r, x in zip(readers, lazy_eval_tensors):
            # get values from lazy_eval_graph and assign to lazy_eval tensor
            x._handle = RawTensor(r.op.get_value())._handle
            x._reset_varnode()

    @contextlib.contextmanager
    def _setup(self):
        interrupted = False

        def do_enter():
            set_tracing()
            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
            self._set_active(True)
            if self._untraced:
                self._init_trace(self._symbolic)
            else:
                if self._graph is None:
                    self._compile()
                self._execute_graph(self._graph)

        def do_finalize():
            escaped_tensors = self._take_escaped_tensors()
            if self._untraced:
                if self._record_only:
                    self._lazy_eval_graph = None
                    self._lazy_eval_tensors = None
                    self._lazy_eval_links = None
                else:
                    for x in escaped_tensors:
                        if x():
                            info = self._tinfo[x()._mixin_handle]
                            info.data_read = True
                            x()._mixin_handle = -1
                            x()._recording = False
                    if self._inputs_to_restore:
                        for x in self._inputs_to_restore:
                            x._mixin_handle = -1
                            x._recording = False
                    if self._symbolic and (
                        self._lazy_eval_tensors or self._lazy_eval_links
                    ):
                        # eval lazy eval tensors
                        self._lazy_eval(
                            self._lazy_eval_graph,
                            self._lazy_eval_tensors,
                            self._lazy_eval_links,
                        )
                        self._lazy_eval_graph = None
                        self._lazy_eval_tensors = None
                        self._lazy_eval_links = None
                    self._untraced = False
            else:
                # compiled_tensor leaks
                if self._pc == len(self._seq):
                    for x in escaped_tensors:
                        try:
                            x().__init__(RawTensor(x()._dev_tensor()))
                        except RuntimeError:
                            # TraceMismatchError thrown in do_exit
                            pass
                self._graph.wait()
                self._reset_exec_env()

            # reset status
            self._pc = 0
            self._tensor_remaps = None
            self._set_active(False)
            set_symbolic_shape(self._save_symbolic_shape)
            unset_tracing()

        def do_exit():
            unset_tracing()
            if not self._untraced and self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            if not self._symbolic or not self._untraced:
                # reset output tensors
                for x in self._active_tensors.copy():
                    strong_x = x()
                    if strong_x is not None:
                        strong_x._dev_tensor()
                        strong_x._reset_varnode()
                        strong_x._mixin_handle = -1
                        strong_x._recording = False
                        strong_x._trace_mixin_info = None

        try:
            do_enter()
            yield
            do_exit()
        except:
            interrupted = True
            raise
        finally:
            do_finalize()
            if interrupted:
                self._reset()

    def _begin_excluded_region(self):
        if self._capture_as_const:
            raise RuntimeError(
                "exclude_from_trace cannot be used with capture_as_const"
            )
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                strong_x = x()
                if strong_x:
                    info = self._tinfo[strong_x._mixin_handle]
                    info.exported = True
                    info.data_read = True
        else:
            for x in self._active_tensors:
                strong_x = x()
                if strong_x:
                    strong_x._dev_tensor()

    def _apply_graph_options(self, graph):
        graph.options.no_force_inplace = True
        graph.options.seq_opt.enable_seq_comp_node_opt = False
        graph.options.graph_opt_level = self._graph_opt_level
        if self._dtr_config is not None:
            graph.options.enable_dtr_memory_opt = True
            graph.options.dtr_config.eviction_threshold = (
                self._dtr_config.eviction_threshold
            )
            graph.options.dtr_config.evictee_minimum_size = (
                self._dtr_config.evictee_minimum_size
            )
            graph.options.dtr_config.recomp_memory_factor = (
                self._dtr_config.recomp_memory_factor
            )
            graph.options.dtr_config.recomp_time_factor = (
                self._dtr_config.recomp_time_factor
            )
        # graph optimization
        if self._graph_opt_config is not None:
            mapping = {None: 0, False: 1, True: 2}
            jit_config = graph.options.graph_opt.jit_config
            jit_config.fuse_dimshuffle = mapping[
                self._graph_opt_config.jit_fuse_dimshuffle
            ]
            jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
            sublinear_config = graph.options.sublinear_mem_config
            sublinear_config.lb_memory_mb = self._sublinear_memory_config.lb_memory_mb
            sublinear_config.genetic_nr_iter = (
                self._sublinear_memory_config.genetic_nr_iter
            )
            sublinear_config.genetic_pool_size = (
                self._sublinear_memory_config.genetic_pool_size
            )
            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
            sublinear_config.num_worker = self._sublinear_memory_config.num_worker
        # profile
        if self._profiling:
            self._profiler = GraphProfiler(graph)
        self._profiler2 = None
        if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
            graph.options.var_sanity_check_first_run = False
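
    # A hedged sketch of driving the options above from user code; it assumes
    # DTRConfig accepts the attributes read in _apply_graph_options as
    # constructor arguments (eviction_threshold in bytes):
    #
    #     from megengine.jit import trace, DTRConfig
    #
    #     config = DTRConfig(eviction_threshold=8 * 1024 ** 3)
    #
    #     @trace(symbolic=True, dtr_config=config)
    #     def train_step(x):
    #         ...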

    def _execute_graph(self, graph: G.Graph, *args):
        if is_profiling() and (self._profiler2 is None):
            self._profiler2 = GraphProfiler2(graph)
        elif not is_profiling() and (self._profiler2 is not None):
            self._profiler2 = None
        graph.execute(*args)

    def _compile(self):
        graph = self._graph = G.Graph()
        graph.options.async_exec_level = 0b100
        self._apply_graph_options(graph)
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        in_out_links = ()
        io_links = ()
        readers = []

        if self._capture_as_const:
            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                info = self._tinfo[h]
                opnode = info.data_setter = G.InputNode(
                    device=info.device,
                    dtype=info.dtype,
                    shape=info.shape or (1,),
                    graph=graph,
                    use_static_shape=_input_node_use_static_shape(),
                )
                need_reset_nodes.append(opnode)
                info.varnode = opnode.outputs[0]
                in_out_links += opnode.outputs[1:]

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    assert info.bound_data
                    info.varnode = graph.make_const(
                        info.bound_data.numpy(),
                        info.bound_data.dtype,
                        info.bound_data.device,
                    )
                continue

            require_links = type(op) in _io_op_types
            ivars = []
            for i, h in enumerate(ihandles):
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        if getattr(info, "is_const", False):
                            info.varnode = graph.make_const(
                                info.bound_data.numpy(),
                                info.bound_data.dtype,
                                info.bound_data.device,
                            )
                        else:
                            info.varnode = graph.make_const(
                                info.bound_data._dev_tensor()
                                # info.bound_data.numpy()
                            )
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *in_out_links,
                            device=info.device,
                            dtype=info.dtype,
                            shape=info.shape or (1,),
                            graph=graph,
                            use_static_shape=_input_node_use_static_shape(),
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *in_out_links = opnode.outputs
                if require_links and i == 0 and len(io_links) > 0:
                    opnode = G.VirtualDepNode(
                        [info.varnode, *io_links], str(io_links[0].device)
                    )
                    info.varnode = opnode.outputs[0]
                    io_links = (info.varnode,)
                ivars.append(info.varnode)

            ovars = G.apply_normal_varnode(op, *ivars)

            if require_links and len(ovars) > 0:
                io_links = (ovars[0],)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal in_out_links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    in_out_links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                    add_reader(opnode)

        graph.options.graph_opt_level = self._graph_opt_level
        graph._set_priority_to_id([*readers, *in_out_links, *io_links])
        graph.compile(*readers, *in_out_links, *io_links)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def __call__(self, *args, **kwargs):
        with self._setup():
            if self._capture_as_const:
                self._process_inputs(*args, **kwargs)
            outputs = self.__wrapped__(*args, **kwargs)
            if self._capture_as_const:
                self._process_outputs(outputs)
            return outputs

    def dump(
        self,
        file,
        *,
        arg_names=None,
        output_names=None,
        append=False,
        keep_var_name: int = 1,
        keep_opr_name: bool = False,
        keep_param_name: bool = False,
        keep_opr_priority: bool = False,
        strip_info_file=None,
        append_json=False,
        optimize_for_inference=True,
        user_info: Any = None,
        enable_metadata: bool = True,
        **kwargs
    ):
        r"""Serializes trace to file system.

        Args:
            file: output file, could be file object or filename.
            arg_names: names of the input tensors in the traced function.
            output_names: names of the output tensors in the traced function,
                use the default name if not specified.
            append: whether output is appended to ``file``.
                Only works when ``file`` is str.
            keep_var_name: level for keeping variable names:

                * 0: none of the names are kept
                * 1: (default) keep names of output vars
                * 2: keep names of all (output and internal) vars

            keep_opr_name: whether to keep operator names.
            keep_param_name: whether to keep param names, so param values can be
                easily manipulated after loading model
            keep_opr_priority: whether to keep priority setting for operators
            strip_info_file: a string for path or a file handler. if is not None,
                then the dump information for code strip would be written to ``strip_info_file``
            append_json: will be checked when ``strip_info_file`` is not None. if set
                true, the information for code strip will be appended to ``strip_info_file``;
                if set false, will rewrite ``strip_info_file``
            optimize_for_inference: enable optimizations;
                will skip all optimize options if this is False. Default: True
            user_info: any type object, which will be pickled to bytes.
            enable_metadata: whether to save metadata into output file.

        Keyword Arguments:

            * enable_io16xc32 --
              whether to use float16 for I/O between oprs and use
              float32 as internal computation precision. Note the output var would be
              changed to float16.
            * enable_ioc16 --
              whether to use float16 for both I/O and computation
              precision.
            * enable_hwcd4 --
              whether to use NHWCD4 data layout. This is faster on some
              OpenCL backends.
            * enable_nchw88 --
              whether to use NCHW88 data layout, currently
              used in X86 AVX backend.
            * enable_nchw44 --
              whether to use NCHW44 data layout, currently
              used in arm backend.
            * enable_nchw44_dot --
              whether to use NCHW44_dot data layout, currently
              used in armv8.2+dotprod backend.
            * enable_nchw4 --
              whether to use NCHW4 data layout, currently
              used in nvidia backend (based on cudnn).
            * enable_nchw32 --
              whether to use NCHW32 data layout, currently
              used in nvidia backend with tensorcore (based on cudnn).
            * enable_chwn4 --
              whether to use CHWN4 data layout, currently
              used in nvidia backend with tensorcore.
            * enable_nchw64 --
              whether to use NCHW64 data layout, used for fast int4
              support on Nvidia GPU.
            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
              into one opr.
            * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
              input for inference on nvidia backend (this optimization pass will
              result in mismatch of the precision of output of training and
              inference)
        """
        if not self._capture_as_const:
            raise ValueError(
                "you must specify capture_as_const=True at __init__ to use dump"
            )
        if self._untraced and len(self._seq) == 0:
            raise RuntimeError("should do record first before dump")
        if self._output_names and output_names:
            raise TypeError(
                "cannot specify output_names when output is already in dict format"
            )
        if output_names and not isinstance(output_names, collections.abc.Sequence):
            output_names = (output_names,)
        if output_names and len(output_names) != len(self._output_bindings):
            raise ValueError(
                "wrong number of output_names, should be {} values".format(
                    len(self._output_bindings)
                )
            )
        without_arg_names = arg_names is None
        if without_arg_names:
            arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
        if arg_names and not isinstance(arg_names, collections.abc.Sequence):
            arg_names = (arg_names,)
        if arg_names and len(arg_names) != len(self._arg_bindings):
            raise ValueError(
                "wrong number of arg_names, should be {} values".format(
                    len(self._arg_bindings)
                )
            )
        output_names = output_names or self._output_names

        def dumped_device(info):
            device_name = info.device.logical_name
            if device_name[:3] in ("cpu", "gpu", "xpu"):
                return as_device("xpux")
            return info.device

        h2v = {}
        graph = G.Graph()
        # apply graph_opt_level in dump
        if self._graph_opt_level is not None:
            graph.options.graph_opt_level = self._graph_opt_level

        for i, h in enumerate(self._arg_bindings):
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=dumped_device(info),
                shape=info.shape or (1,),
                name=info.name if without_arg_names and info.name else arg_names[i],
            )
        for k, h in self._kwarg_bindings.items():
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=dumped_device(info),
                shape=info.shape or (1,),
                name=k,
            )

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(
                        info.bound_data.numpy(),
                        dtype=info.dtype,
                        device=dumped_device(info),
                        name=info.name,
                    )
                continue

            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(
                        info.bound_data.numpy(),
                        dtype=info.dtype,
                        device=dumped_device(info),
                        name=info.name,
                    )
                ivars.append(h2v[h])
            if isinstance(op, BatchNorm):
                assert (
                    op.fwd_mode == BatchNorm.FwdMode.INFERENCE
                ), "can not dump BatchNorm in training mode, maybe you forgot to do model.eval()?"
            ovars = G.apply_normal_varnode(op, *ivars)
            AutoNaming.record_opnode(ovars[0].op)

            assert len(ovars) == len(ohandles)
            h2v.update(zip(ohandles, ovars))

            for i in ohandles:
                name = AutoNaming.get_var_name(i)
                if name is not None:
                    h2v[i].name = name

        AutoNaming.remove_duplicate_names()

        dest_vars = []
        for i, h in enumerate(self._output_bindings):
            v = h2v[h]
            if output_names:
                v.name = output_names[i]
            dest_vars.append(v)

        if optimize_for_inference:
            dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)

        metadata = SerializationMetadata()
        if enable_metadata:
            metadata.user_info = pickle.dumps(user_info)
            metadata.is_valid = True
            metadata.graph_modified = False
            if optimize_for_inference:
                metadata.optimize_options = optimize_options

        if isinstance(file, str):
            permission = "wb" if append == False else "ab"
            file = open(file, permission)

        if keep_opr_priority:
            graph._set_priority_to_id(dest_vars)

        dump_content, dump_info = G.dump_graph(
            dest_vars,
            keep_var_name=keep_var_name,
            keep_opr_name=keep_opr_name,
            keep_param_name=keep_param_name,
            keep_opr_priority=keep_opr_priority,
            strip_info_file=strip_info_file,
            append_json=append_json,
            metadata=metadata,
        )
        file.write(dump_content)
        return dump_info
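
    # A minimal usage sketch (hedged; assumes only the public ``megengine``
    # API): dump requires capture_as_const=True and at least one recorded call
    # before serialization.
    #
    #     import numpy as np
    #     import megengine as mge
    #     from megengine.jit import trace
    #
    #     @trace(symbolic=True, capture_as_const=True)
    #     def pred(x):
    #         return x * 2
    #
    #     pred(mge.tensor(np.ones(4, dtype="float32")))  # record the graph once
    #     pred.dump("model.mge", arg_names=["data"])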

    def _process_inputs(self, *args, **kwargs):
        if self._untraced:
            self._inputs_to_restore = []

            def record_input(x):
                if x is None:
                    return
                h, info = self._new_handle()
                info.external = False
                info.name = x.c_name
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.numpy().shape
                x._mixin_handle = h
                x._recording = True
                x._trace_mixin_info = info
                self._inputs_to_restore.append(x)
                return h

            self._arg_bindings = []
            for i, x in enumerate(args):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                self._arg_bindings.append(record_input(x))

            self._kwarg_bindings = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    self._kwarg_bindings[k] = record_input(x)
        else:
            if len(args) != len(self._arg_bindings):
                raise TraceMismatchError("positional argument length mismatch")

            self._tensor_remaps = {}

            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("args[%d].dtype different from last time" % i)
                if x.device != info.device:
                    raise TypeError("args[%d].device different from last time" % i)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

            kwargs_tensors = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    kwargs_tensors[k] = x
            if set(kwargs_tensors) != set(self._kwarg_bindings):
                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                if too_many:
                    raise TraceMismatchError(
                        "keyword arguments found to be tensor this time "
                        "but were non-tensor previously: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "keyword arguments found to be non-tensor this time "
                        "but were tensor previously: %s" % " ".join(too_few)
                    )
            for k, h in self._kwarg_bindings.items():
                x = kwargs_tensors[k]
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("kwargs[%s].dtype different from last time" % k)
                if x.device != info.device:
                    raise TypeError("kwargs[%s].device different from last time" % k)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

    def _process_outputs(self, outputs):
        output_names = None
        if isinstance(outputs, collections.abc.Mapping):
            output_names, outputs = zip(*sorted(outputs.items()))
        elif not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)
        if not self._untraced:
            if output_names != self._output_names:
                too_many = set(output_names) - set(self._output_names)
                too_few = set(self._output_names) - set(output_names)
                if too_many:
                    raise TraceMismatchError(
                        "output has more keys than last time: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "output has fewer keys than last time: %s" % " ".join(too_few)
                    )
            if len(outputs) != len(self._output_bindings):
                raise TraceMismatchError("output size differs from last time")
        else:
            self._output_names = output_names
            self._output_bindings = []

        for i, x in enumerate(outputs):
            if not isinstance(x, RawTensor):
                raise TypeError("every item of return value should be tensor")
            if self._untraced:
                h = x._mixin_handle
                if h < 0:
                    raise RuntimeError("output is not computed from inputs")
                self._output_bindings.append(h)
            else:
                h = x._mixin_handle
                if h not in self._output_handles:
                    raise RuntimeError("output is not computed from inputs")
                if h != self._output_bindings[i]:
                    raise TraceMismatchError(
                        "retval[%s] is a different tensor than last time"
                        % (output_names and output_names[i] or i)
                    )

    def get_profile(self):
        r"""Get profiling result for compiled trace.

        Return:
            a json compatible object.
        """
        if not self._profiler:
            raise RuntimeError("trace is not set with profiling=True")
        return json.loads(self._profiler.get())
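
    # A minimal usage sketch (hedged; assumes only the public ``megengine``
    # API): construct the trace with profiling=True, run it, then fetch the
    # result with get_profile().
    #
    #     import json
    #     import numpy as np
    #     import megengine as mge
    #     from megengine.jit import trace
    #
    #     @trace(symbolic=True, profiling=True)
    #     def f(x):
    #         return x + 1
    #
    #     f(mge.tensor(np.ones(4, dtype="float32")))  # first call records
    #     f(mge.tensor(np.ones(4, dtype="float32")))  # second call runs compiled graph
    #     with open("profile.json", "w") as fp:
    #         json.dump(f.get_profile(), fp)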


class CompiledTensorProxy:
    r"""Duck-typed RawTensor"""

    def __init__(self, handle):
        self.__handle = handle
        self._isscalar = False
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self._isscalar:
            return ()
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                # c++ will throw TraceReadError
                return None
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                # c++ will throw TraceReadError
                return None
        # c++ side will handle scalar case
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                # c++ will throw TraceReadError
                return None
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()


def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = []
    for x in args:
        var = getattr(x, "_varnode", None)
        if var:
            ivars.append(var)
        else:
            data_setter = G.InputNode(
                device=x.device,
                dtype=x.dtype,
                shape=x.numpy().shape or (1,),
                graph=graph,
                use_static_shape=True,
            )
            var = data_setter.outputs[0]
            ivars.append(var)
            data_setter.set_value(x._dev_tensor())

    require_links = type(op) in _io_op_types

    if require_links and active_trace._lazy_eval_links:
        assert len(ivars) > 0, "op should have at least one input"
        opnode = G.VirtualDepNode(
            [ivars[0], *active_trace._lazy_eval_links],
            str(active_trace._lazy_eval_links[0].device),
        )
        ivars[0] = opnode.outputs[0]
        active_trace._lazy_eval_links = (ivars[0],)

    ovars = G.apply_normal_varnode(op, *ivars)
    outputs = [RawTensor(o) for o in ovars]

    if require_links:
        active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)

    return outputs


def apply_const_symbolic_mode(value, dtype, device, name):
    graph = active_trace._lazy_eval_graph
    # don't need to unset tracing
    # because varnode construction will ignore tracing flag
    ret = RawTensor(graph.make_const(value, dtype=dtype, device=device, name=name))
    if np.array(value).ndim == 0:
        setscalar(ret)
    return (ret,)


def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        unset_tracing()
        ret = apply(op, *args)
        set_tracing()
        return ret
    return active_trace._apply_op(op, args)


def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
    if skip_tracing:
        unset_tracing()
        ret = RawTensor(value, dtype, device, False, name)
        set_tracing()
        return ret
    return active_trace._apply_const(value, dtype, device)


def apply_with_tracing(op: OpDef, *args: RawTensor):
    if active_trace._graph:
        # if member _graph exists, then is_compiled
        return apply_compiled_mode(op, *args)
    if hasattr(op, "scope"):
        op.scope = AutoNaming.get_scope()
    if active_trace._symbolic:
        outputs = apply_symbolic_mode(op, *args)
    else:
        unset_tracing()
        outputs = apply(op, *args)
        set_tracing()

    active_trace._record_op(op, args, outputs)
    return list(outputs)


def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
    if active_trace._graph:
        return apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name)
    if active_trace._symbolic:
        outputs = apply_const_symbolic_mode(value, dtype, device, name)
    else:
        unset_tracing()
        outputs = RawTensor(value, dtype, device, False, name)
        if np.array(value).ndim == 0:
            setscalar(outputs)
        outputs = (outputs,)
        set_tracing()
    active_trace._record_const(outputs)
    return list(outputs)
