You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tracing.py 44 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import contextlib
  11. import functools
  12. import itertools
  13. import json
  14. import os
  15. import numpy as np
  16. from ..core._imperative_rt import GraphProfiler
  17. from ..core._imperative_rt.core2 import Tensor as RawTensor
  18. from ..core._imperative_rt.core2 import (
  19. TensorWeakRef,
  20. apply,
  21. set_tracing,
  22. skip_tracing,
  23. unset_tracing,
  24. )
  25. from ..core._imperative_rt.ops import (
  26. AssertEqual,
  27. CollectiveComm,
  28. RemoteRecv,
  29. RemoteSend,
  30. )
  31. from ..core._trace_option import set_symbolic_shape
  32. from ..core._wrap import device as as_device
  33. from ..core.ops.builtin import BatchNorm, OpDef
  34. from ..core.ops.special import Const
  35. from ..core.tensor import megbrain_graph as G
  36. from ..core.tensor.utils import setscalar
  37. from ..utils.naming import AutoNaming
  38. from .dtr_config import DTRConfig
  39. from .graph_opt_config import GraphOptimizationConfig
  40. from .sublinear_memory_config import SublinearMemoryConfig
  41. def _input_node_use_static_shape():
  42. return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
  43. class TraceMismatchError(RuntimeError):
  44. pass
# The ``trace`` instance currently installed by ``trace._set_active``;
# ``None`` while no trace is running.
active_trace = None
  46. def is_tracing():
  47. if active_trace is None:
  48. return False
  49. else:
  50. return not skip_tracing
  51. @contextlib.contextmanager
  52. def exclude_from_trace():
  53. global skip_tracing
  54. if skip_tracing:
  55. yield
  56. return
  57. try:
  58. skip_tracing = True
  59. unset_tracing()
  60. if active_trace is not None:
  61. active_trace._begin_excluded_region()
  62. yield
  63. finally:
  64. skip_tracing = False
  65. set_tracing()
  66. class TensorInfo:
  67. __slots__ = (
  68. # collected attributes
  69. "name",
  70. "external",
  71. "data_read",
  72. "shape_read",
  73. "value_read",
  74. "exported",
  75. "device",
  76. "dtype",
  77. "shape",
  78. "is_const",
  79. "bound_data",
  80. # resources for execution
  81. "varnode",
  82. "data_setter",
  83. "shape_reader",
  84. "value_reader",
  85. "data_reader",
  86. )
  87. def __init__(self):
  88. self.name = None
  89. self.exported = None
  90. self.data_read = None
  91. self.shape_read = None
  92. self.value_read = None
  93. self.bound_data = None
  94. self.data_setter = None
  95. self.shape_reader = None
  96. self.value_reader = None
  97. self.data_reader = None
# Operator types for which ``trace._compile`` sets ``require_links``, i.e.
# whose inputs/outputs are chained through ``io_links``.
_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}
  99. class trace:
  100. """
  101. Wraps a callable and provide:
  102. * tracing via :meth:`.trace` and :meth:`.dump`
  103. * accelerated evaluation via :meth:`.__call__`
  104. :param function: the function will be traced.
  105. :param symbolic: whether to apply symbolic execution for tracing. Default: False
  106. :param capture_as_const: capture global vars or closures as const value. Default: False
  107. :param sublinear_memory_config: configuration for sublinear memory optimization.
  108. If not None, it enables sublinear memory optimization with given setting.
  109. :param profiling: whether to profile compiled trace. Default: False
  110. :param opt_level: optimization level for compiling trace. Default: 2
  111. :param graph_opt_config: configuration for graph optimization. Default: None
  112. :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
  113. """
  114. def __new__(cls, *args, **kwargs):
  115. if not args:
  116. return functools.partial(cls, **kwargs)
  117. return super().__new__(cls)
  118. def __init__(
  119. self,
  120. function,
  121. symbolic=False,
  122. capture_as_const=False,
  123. sublinear_memory_config: SublinearMemoryConfig = None,
  124. dtr_config: DTRConfig = None,
  125. profiling: bool = False,
  126. opt_level: int = 2,
  127. graph_opt_config: GraphOptimizationConfig = None,
  128. symbolic_shape: bool = True,
  129. ):
  130. self.__wrapped__ = function
  131. self._symbolic = symbolic
  132. self._capture_as_const = capture_as_const
  133. self._sublinear_memory_config = sublinear_memory_config
  134. self._dtr_config = dtr_config
  135. self._profiling = profiling
  136. self._profiler = None
  137. self._graph_opt_level = opt_level
  138. self._graph_opt_config = graph_opt_config
  139. self._symbolic_shape = symbolic_shape
  140. self._output_handles = set()
  141. self._reset()
  142. def _reset(self):
  143. self._untraced = True
  144. self._tinfo = [] # handle -> TensorInfo
  145. self._seq = []
  146. self._pc = 0
  147. self._graph = None
  148. self._need_reset_nodes = None
  149. self._lazy_eval_graph = None
  150. self._lazy_eval_tensors = set()
  151. self._lazy_eval_links = None
  152. self._active_tensors = set()
  153. self._tensor_remaps = None
  154. self._inputs_to_restore = None
  155. self._arg_bindings = None
  156. self._kwarg_bindings = None
  157. self._output_bindings = None
  158. self._output_names = None
  159. def _new_handle(self):
  160. handle = len(self._tinfo)
  161. info = TensorInfo()
  162. self._tinfo.append(info)
  163. return handle, info
    def _apply_op(self, op, args):
        """Replay one operator against the recorded trace.

        Checks that ``op`` and ``args`` agree with the record at the current
        program counter, feeds external inputs to their data setters, then
        advances the counter and returns new tensors wrapping the recorded
        output varnodes.

        :param op: the operator being applied in this call.
        :param args: the input tensors of ``op``.
        :return: list of output tensors, one per recorded output handle.
        :raises TraceMismatchError: if the op, its arity, or any input differs
            from what was recorded.
        """
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        # a recorded "Const" entry can never match a real op
        if (isinstance(op_, str) and op_ == "Const") or (op != op_):
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        # check all inputs of current op
        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                # recorded as external: a compiled (internal) tensor is only
                # acceptable if its handle was marked as exported
                if (
                    x._compiled_info is not None
                    and not self._tinfo[x._mixin_handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    # captured as a constant: must be the same handle or
                    # compare equal by value
                    if x._compiled_info is not None:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        if not np.array_equal(x.numpy(), info.bound_data.numpy()):
                            raise TraceMismatchError(
                                "const capture violated: got "
                                "a different tensor this time"
                            )
                else:
                    # plain external input: dtype/device must match, then the
                    # value is handed to the input's data setter
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                # recorded as internal: the tensor must carry handle ``h``,
                # possibly after being looked up in ``_tensor_remaps``
                if x._mixin_handle == -1:
                    if x._handle not in self._tensor_remaps:
                        raise TraceMismatchError(
                            "unexpected capture: trying to use an external tensor as "
                            "input, but that input was an internal tensor last time"
                        )
                    else:
                        x._mixin_handle = self._tensor_remaps[
                            x._handle
                        ]._CompiledTensorProxy__handle
                if x._mixin_handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to an data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = []
        for h in ohandles:
            info = self._tinfo[h]
            # generate output tensor and create compiled info
            y = RawTensor(info.varnode)
            y._compiled_info = CompiledTensorProxy(h)
            y._mixin_handle = h
            outputs += [y]
            # only a weak reference is kept for the new tensor
            self._active_tensors.add(TensorWeakRef(y))
        self._output_handles.update(ohandles)

        return outputs
  237. def _apply_const(self, value, dtype, device):
  238. assert not self._untraced
  239. # check against trace
  240. if self._pc >= len(self._seq):
  241. raise TraceMismatchError("trace should end here, but more op observed")
  242. record = self._seq[self._pc]
  243. op_, ihandles, ohandles = record
  244. # Const op is represented by a str
  245. assert isinstance(op_, str) and op_ == "Const"
  246. eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
  247. if not eq:
  248. raise TraceMismatchError(
  249. "const tensor violated: got a different tensor this time"
  250. )
  251. self._pc += 1
  252. (h,) = ohandles
  253. outputs = [self._tinfo[h].bound_data]
  254. return outputs
  255. # run in first step, record information for trace
  256. def _record_op(self, op, inputs, outputs):
  257. if skip_tracing:
  258. for x in inputs:
  259. h = getattr(x, "_mixin_handle", -1)
  260. if h >= 0:
  261. self._tinfo[h].data = True
  262. return
  263. ihandles = []
  264. for x in inputs:
  265. h = getattr(x, "_mixin_handle", -1)
  266. if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
  267. h, info = self._new_handle()
  268. name = AutoNaming.gen_name(x)
  269. info.name = name
  270. info.external = True
  271. info.device = x.device
  272. info.dtype = x.dtype
  273. info.shape = x.shape
  274. if self._capture_as_const:
  275. info.bound_data = RawTensor(
  276. x.numpy(), x.dtype, x.device, False, name
  277. )
  278. ihandles.append(h)
  279. ohandles = []
  280. for x in outputs:
  281. h, info = self._new_handle()
  282. ohandles.append(h)
  283. info.external = False
  284. x._mixin_handle = h
  285. x._recording = True
  286. x._trace_mixin_info = info
  287. self._active_tensors.add(TensorWeakRef(x))
  288. if self._symbolic:
  289. self._lazy_eval_tensors.add(TensorWeakRef(x))
  290. self._seq.append((op, tuple(ihandles), tuple(ohandles)))
  291. def _record_const(self, outputs):
  292. if skip_tracing:
  293. (x,) = outputs
  294. h = getattr(x, "_mixin_handle", -1)
  295. if h >= 0:
  296. self._tinfo[h].data_read = True
  297. return
  298. (x,) = outputs
  299. h, info = self._new_handle()
  300. ohandles = [h]
  301. info.external = True
  302. info.device = x.device
  303. info.dtype = x.dtype
  304. info.shape = x.shape
  305. info.bound_data = x
  306. info.is_const = True
  307. x._mixin_handle = h
  308. x._recording = True
  309. x._trace_mixin_info = info
  310. if self._symbolic:
  311. self._lazy_eval_tensors.add(TensorWeakRef(x))
  312. self._seq.append(("Const", tuple(), tuple(ohandles)))
  313. def _set_active(self, active: bool):
  314. global active_trace
  315. if active:
  316. if active_trace:
  317. raise NotImplementedError("sorry, not implemented: nested trace")
  318. active_trace = self
  319. else:
  320. assert active_trace is self
  321. active_trace = None
  322. def _init_trace(self, symbolic: bool):
  323. if symbolic:
  324. self._lazy_eval_graph = G.Graph()
  325. self._apply_graph_options(self._lazy_eval_graph)
  326. self._lazy_eval_links = ()
  327. def _take_escaped_tensors(self):
  328. escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
  329. self._active_tensors.clear()
  330. return escaped_tensors
  331. def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
  332. lazy_eval_tensors = [x() for x in lazy_eval_tensors]
  333. lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None]
  334. readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors]
  335. self._apply_graph_options(lazy_eval_graph)
  336. lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
  337. lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
  338. lazy_eval_graph.compile(*lazy_eval_links, *readers)
  339. lazy_eval_graph()
  340. for r, x in zip(readers, lazy_eval_tensors):
  341. # get values from lazy_eval_graph and assign to lazy_eval tensor
  342. x._handle = RawTensor(r.op.get_value())._handle
  343. x._reset_varnode()
    @contextlib.contextmanager
    def _setup(self):
        """Context manager wrapping one call of the traced function.

        ``do_enter`` runs before the body, ``do_exit`` after a body that
        finished normally, and ``do_finalize`` always runs last.  If anything
        raised, the exception propagates and the trace is reset afterwards.
        """
        interrupted = False

        def do_enter():
            # Switch tracing on and make this trace the active one; on a
            # replay, compile the graph if needed and start executing it.
            set_tracing()
            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
            self._set_active(True)
            if self._untraced:
                self._init_trace(self._symbolic)
            else:
                if self._graph is None:
                    self._compile()
                self._graph.execute()

        def do_finalize():
            escaped_tensors = self._take_escaped_tensors()
            if self._untraced:
                # First (recording) run: tensors still alive are marked as
                # data_read and detached from the trace.
                for x in escaped_tensors:
                    if x():
                        info = self._tinfo[x()._mixin_handle]
                        info.data_read = True
                        x()._mixin_handle = -1
                        x()._recording = False
                if self._inputs_to_restore:
                    for x in self._inputs_to_restore:
                        x._mixin_handle = -1
                        x._recording = False
                if self._symbolic and (
                    self._lazy_eval_tensors or self._lazy_eval_links
                ):
                    # eval lazy eval tensors
                    self._lazy_eval(
                        self._lazy_eval_graph,
                        self._lazy_eval_tensors,
                        self._lazy_eval_links,
                    )
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
                self._lazy_eval_links = None
                self._untraced = False
            else:
                # compiled_tensor leaks
                if self._pc == len(self._seq):
                    for x in escaped_tensors:
                        try:
                            assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
                        except RuntimeError:
                            # TraceMismatchError thrown in do_exit
                            pass
                    self._graph.wait()
                    self._reset_exec_env()

            # reset status
            self._pc = 0
            self._tensor_remaps = None
            self._set_active(False)
            set_symbolic_shape(self._save_symbolic_shape)
            unset_tracing()

        def do_exit():
            unset_tracing()
            # a replay must consume the whole recorded sequence
            if not self._untraced and self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            if not self._symbolic or not self._untraced:
                # reset output tensors
                # (iterates over a copy of the active set)
                for x in self._active_tensors.copy():
                    strong_x = x()
                    if strong_x is not None:
                        strong_x._dev_tensor()
                        strong_x._reset_varnode()
                        strong_x._mixin_handle = -1
                        strong_x._recording = False
                        strong_x._trace_mixin_info = None

        try:
            do_enter()
            yield
            do_exit()
        except:
            # remember the failure (the exception is re-raised unchanged) so
            # that the trace is reset once finalization is done
            interrupted = True
            raise
        finally:
            do_finalize()
            if interrupted:
                self._reset()
  425. def _begin_excluded_region(self):
  426. if self._capture_as_const:
  427. raise RuntimeError(
  428. "exclude_from_trace cannot be used with capture_as_const"
  429. )
  430. if self._untraced:
  431. # conditionally reading a compiled tensor in excluded region
  432. # is permitted, so we have to assume every tensor might be read
  433. for x in self._active_tensors:
  434. strong_x = x()
  435. if strong_x:
  436. info = self._tinfo[strong_x._mixin_handle]
  437. info.exported = True
  438. info.data_read = True
  439. else:
  440. for x in self._active_tensors:
  441. strong_x = x()
  442. if strong_x:
  443. strong_x._dev_tensor()
  444. def _apply_graph_options(self, graph):
  445. graph.options.no_force_inplace = True
  446. graph.options.seq_opt.enable_seq_comp_node_opt = False
  447. graph.options.graph_opt_level = self._graph_opt_level
  448. if self._dtr_config is not None:
  449. graph.options.enable_dtr_memory_opt = True
  450. graph.options.dtr_config.eviction_threshold = (
  451. self._dtr_config.eviction_threshold
  452. )
  453. graph.options.dtr_config.evictee_minimum_size = (
  454. self._dtr_config.evictee_minimum_size
  455. )
  456. # graph optimization
  457. if self._graph_opt_config is not None:
  458. mapping = {None: 0, False: 1, True: 2}
  459. jit_config = graph.options.graph_opt.jit_config
  460. jit_config.fuse_dimshuffle = mapping[
  461. self._graph_opt_config.jit_fuse_dimshuffle
  462. ]
  463. jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
  464. # sublinear
  465. if self._sublinear_memory_config is not None:
  466. graph.options.enable_sublinear_memory_opt = True
  467. sublinear_config = graph.options.sublinear_mem_config
  468. sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
  469. sublinear_config.genetic_nr_iter = (
  470. self._sublinear_memory_config.genetic_nr_iter
  471. )
  472. sublinear_config.genetic_pool_size = (
  473. self._sublinear_memory_config.genetic_pool_size
  474. )
  475. sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
  476. sublinear_config.num_worker = self._sublinear_memory_config.num_worker
  477. # profile
  478. if self._profiling:
  479. self._profiler = GraphProfiler(graph)
  480. if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
  481. graph.options.var_sanity_check_first_run = False
    def _compile(self):
        """Build and compile the execution graph from the recorded sequence.

        Walks ``self._seq`` once, creating a varnode for every handle: input
        nodes for external tensors, constants for bound data, and the results
        of applying each recorded op.  Output nodes are added for every
        handle flagged as data/value/shape read.  Nodes that need resetting
        after each run are collected in ``self._need_reset_nodes``.
        """
        graph = self._graph = G.Graph()
        graph.options.async_exec_level = 0b100
        self._apply_graph_options(graph)
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        in_out_links = ()
        io_links = ()
        readers = []

        if self._capture_as_const:
            # bound positional/keyword arguments each get their own input node
            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                info = self._tinfo[h]
                opnode = info.data_setter = G.InputNode(
                    device=info.device,
                    dtype=info.dtype,
                    shape=info.shape or (1,),
                    graph=graph,
                    use_static_shape=_input_node_use_static_shape(),
                )
                need_reset_nodes.append(opnode)
                info.varnode = opnode.outputs[0]
                in_out_links += opnode.outputs[1:]

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                # recorded constant: materialize it once as a graph constant
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    assert info.bound_data
                    info.varnode = graph.make_const(
                        info.bound_data.numpy(),
                        info.bound_data.dtype,
                        info.bound_data.device,
                    )
                continue

            require_links = type(op) in _io_op_types
            ivars = []
            for i, h in enumerate(ihandles):
                info = self._tinfo[h]
                # a handle without a varnode has not been produced inside the
                # graph yet, so it must be an external tensor
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        if getattr(info, "is_const", False):
                            info.varnode = graph.make_const(
                                info.bound_data.numpy(),
                                info.bound_data.dtype,
                                info.bound_data.device,
                            )
                        else:
                            info.varnode = graph.make_const(
                                info.bound_data._dev_tensor()
                                # info.bound_data.numpy()
                            )
                    else:
                        # unbound external input: fed at run time through an
                        # input node chained onto the current in/out links
                        opnode = info.data_setter = G.InputNode(
                            *in_out_links,
                            device=info.device,
                            dtype=info.dtype,
                            shape=info.shape or (1,),
                            graph=graph,
                            use_static_shape=_input_node_use_static_shape(),
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *in_out_links = opnode.outputs
                if require_links and i == 0 and len(io_links) > 0:
                    # make the first input of an I/O op depend on the
                    # previous I/O links
                    opnode = G.VirtualDepNode(
                        [info.varnode, *io_links], str(io_links[0].device)
                    )
                    info.varnode = opnode.outputs[0]
                    io_links = (info.varnode,)

                ivars.append(info.varnode)

            ovars = G.apply_normal_varnode(op, *ivars)

            if require_links and len(ovars) > 0:
                io_links = (ovars[0],)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    # register an output node and chain later nodes after it
                    nonlocal in_out_links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    in_out_links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                    add_reader(opnode)

        graph.options.graph_opt_level = self._graph_opt_level
        graph._set_priority_to_id([*readers, *in_out_links, *io_links])
        graph.compile(*readers, *in_out_links, *io_links)
  582. def _reset_exec_env(self):
  583. for opnode in self._need_reset_nodes:
  584. opnode.reset()
  585. def __call__(self, *args, **kwargs):
  586. with self._setup():
  587. if self._capture_as_const:
  588. self._process_inputs(*args, **kwargs)
  589. outputs = self.__wrapped__(*args, **kwargs)
  590. if self._capture_as_const:
  591. self._process_outputs(outputs)
  592. return outputs
  593. def dump(
  594. self,
  595. file,
  596. *,
  597. arg_names=None,
  598. output_names=None,
  599. append=False,
  600. keep_var_name: int = 1,
  601. keep_opr_name: bool = False,
  602. keep_param_name: bool = False,
  603. keep_opr_priority: bool = False,
  604. strip_info_file=None,
  605. append_json=False,
  606. optimize_for_inference=True,
  607. **kwargs
  608. ):
  609. r"""
  610. Serializes trace to file system.
  611. :param file: output file, could be file object or filename.
  612. :param arg_names: names of the input tensors in the traced function.
  613. :param output_names: names of the output tensors in the traced function,
  614. use the default name if not specified.
  615. :param append: whether output is appended to ``file``.
  616. Only works when ``file`` is str.
  617. :param keep_var_name: level for keeping variable names:
  618. * 0: none of the names are kept
  619. * 1: (default)keep names of output vars
  620. * 2: keep names of all (output and internal) vars
  621. :param keep_opr_name: whether to keep operator names.
  622. :param keep_param_name: whether to keep param names, so param values can be
  623. easily manipulated after loading model
  624. :param keep_opr_priority: whether to keep priority setting for operators
  625. :param strip_info_file: a string for path or a file handler. if is not None,
  626. then the dump information for code strip would be written to ``strip_info_file``
  627. :param append_json: will be check when `strip_info_file` is not None. if set
  628. true, the information for code strip will be append to strip_info_file.
  629. if set false, will rewrite strip_info_file
  630. :param optimize_for_inference: enbale optmizations,
  631. will skip all optimize options if this is False. Default: True
  632. :Keyword Arguments:
  633. * enable_io16xc32 --
  634. whether to use float16 for I/O between oprs and use
  635. float32 as internal computation precision. Note the output var would be
  636. changed to float16.
  637. * enable_ioc16 --
  638. whether to use float16 for both I/O and computation
  639. precision.
  640. * enable_hwcd4 --
  641. whether to use NHWCD4 data layout. This is faster on some
  642. OpenCL backend.
  643. * enable_nchw88 --
  644. whether to use NCHW88 data layout, currently
  645. used in X86 AVX backend.
  646. * enable_nchw44 --
  647. whether to use NCHW44 data layout, currently
  648. used in arm backend.
  649. * enable_nchw44_dot --
  650. whether to use NCHW44_dot data layout, currently
  651. used in armv8.2+dotprod backend.
  652. * enable_nchw4 --
  653. whether to use NCHW4 data layout, currently
  654. used in nvidia backend(based on cudnn).
  655. * enable_nchw32 --
  656. whether to use NCHW32 data layout, currently
  657. used in nvidia backend with tensorcore(based on cudnn).
  658. * enable_chwn4 --
  659. whether to use CHWN4 data layout, currently
  660. used in nvidia backend with tensorcore.
  661. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
  662. into one opr.
  663. * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
  664. input for inference on nvidia backend(this optimization pass will
  665. result in mismatch of the precision of output of training and
  666. inference)
  667. """
  668. if not self._capture_as_const:
  669. raise ValueError(
  670. "you must specify capture_as_const=True at __init__ to use dump"
  671. )
  672. if self._untraced:
  673. raise RuntimeError("should run at least once before calling dump")
  674. if self._output_names and output_names:
  675. raise TypeError(
  676. "cannot specify output_names when output is already in dict format"
  677. )
  678. if output_names and not isinstance(output_names, collections.abc.Sequence):
  679. output_names = (output_names,)
  680. if output_names and len(output_names) != len(self._output_bindings):
  681. raise ValueError(
  682. "wrong number of output_names, should be {} values".format(
  683. len(self._output_bindings)
  684. )
  685. )
  686. without_arg_names = arg_names is None
  687. if without_arg_names:
  688. arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
  689. if arg_names and not isinstance(arg_names, collections.abc.Sequence):
  690. arg_names = (arg_names,)
  691. if arg_names and len(arg_names) != len(self._arg_bindings):
  692. raise ValueError(
  693. "wrong number of arg_names, should be {} values".format(
  694. len(self._arg_bindings)
  695. )
  696. )
  697. output_names = output_names or self._output_names
  698. def dumped_device(info):
  699. device_name = info.device.logical_name
  700. if device_name[:3] in ("cpu", "gpu", "xpu"):
  701. return as_device("xpux")
  702. return info.device
  703. h2v = {}
  704. graph = G.Graph()
  705. # apply graph_opt_level in dump
  706. if self._graph_opt_level is not None:
  707. graph.options.graph_opt_level = self._graph_opt_level
  708. for i, h in enumerate(self._arg_bindings):
  709. info = self._tinfo[h]
  710. h2v[h] = graph.make_h2d(
  711. dtype=info.dtype,
  712. device=dumped_device(info),
  713. shape=info.shape or (1,),
  714. name=info.name if without_arg_names and info.name else arg_names[i],
  715. )
  716. for k, h in self._kwarg_bindings.items():
  717. info = self._tinfo[h]
  718. h2v[h] = graph.make_h2d(
  719. dtype=info.dtype,
  720. device=dumped_device(info),
  721. shape=info.shape or (1,),
  722. name=k,
  723. )
  724. for op, ihandles, ohandles in self._seq:
  725. if isinstance(op, str) and op == "Const":
  726. assert len(ihandles) == 0
  727. (h,) = ohandles
  728. info = self._tinfo[h]
  729. if h not in h2v:
  730. assert info.external
  731. assert info.bound_data
  732. h2v[h] = graph.make_const(
  733. info.bound_data.numpy(),
  734. dtype=info.dtype,
  735. device=info.device,
  736. name=info.name,
  737. )
  738. continue
  739. ivars = []
  740. for h in ihandles:
  741. info = self._tinfo[h]
  742. if h not in h2v:
  743. assert info.external
  744. assert info.bound_data
  745. h2v[h] = graph.make_const(
  746. info.bound_data.numpy(),
  747. dtype=info.dtype,
  748. device=dumped_device(info),
  749. name=info.name,
  750. )
  751. ivars.append(h2v[h])
  752. if isinstance(op, BatchNorm):
  753. assert (
  754. op.fwd_mode == BatchNorm.FwdMode.INFERENCE
  755. ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
  756. ovars = G.apply_normal_varnode(op, *ivars)
  757. AutoNaming.record_opnode(ovars[0].op)
  758. assert len(ovars) == len(ohandles)
  759. h2v.update(zip(ohandles, ovars))
  760. for i in ohandles:
  761. name = AutoNaming.get_var_name(i)
  762. if name is not None:
  763. h2v[i].name = name
  764. AutoNaming.remove_duplicate_names()
  765. dest_vars = []
  766. for i, h in enumerate(self._output_bindings):
  767. v = h2v[h]
  768. if output_names:
  769. v.name = output_names[i]
  770. dest_vars.append(v)
  771. if optimize_for_inference:
  772. dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
  773. if isinstance(file, str):
  774. permission = "wb" if append == False else "ab"
  775. file = open(file, permission)
  776. dump_content, dump_info = G.dump_graph(
  777. dest_vars,
  778. keep_var_name=keep_var_name,
  779. keep_opr_name=keep_opr_name,
  780. keep_param_name=keep_param_name,
  781. keep_opr_priority=keep_opr_priority,
  782. strip_info_file=strip_info_file,
  783. append_json=append_json,
  784. )
  785. file.write(dump_content)
  786. return dump_info
  787. def _process_inputs(self, *args, **kwargs):
  788. if self._untraced:
  789. self._inputs_to_restore = []
  790. def record_input(x):
  791. if x is None:
  792. return
  793. h, info = self._new_handle()
  794. info.external = False
  795. info.name = x.c_name
  796. info.device = x.device
  797. info.dtype = x.dtype
  798. info.shape = x.numpy().shape
  799. x._mixin_handle = h
  800. x._recording = True
  801. x._trace_mixin_info = info
  802. self._inputs_to_restore.append(x)
  803. return h
  804. self._arg_bindings = []
  805. for i, x in enumerate(args):
  806. if not isinstance(x, RawTensor):
  807. raise TypeError(
  808. "positional arguments should all be tensor "
  809. "but args[%d] cannot be recognized as one" % i
  810. )
  811. self._arg_bindings.append(record_input(x))
  812. self._kwarg_bindings = {}
  813. for k, x in kwargs.items():
  814. if isinstance(x, RawTensor):
  815. self._kwarg_bindings[k] = record_input(x)
  816. else:
  817. if len(args) != len(self._arg_bindings):
  818. raise TraceMismatchError("positional argument length mismatch")
  819. self._tensor_remaps = {}
  820. for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
  821. if not isinstance(x, RawTensor):
  822. raise TypeError(
  823. "positional arguments should all be tensor "
  824. "but args[%d] cannot be recognized as one" % i
  825. )
  826. info = self._tinfo[h]
  827. if x.dtype != info.dtype:
  828. raise TypeError("args[%d].dtype different from last time" % i)
  829. if x.device != info.device:
  830. raise TypeError("args[%d].device different from last time" % i)
  831. info.data_setter.set_value(x._dev_tensor())
  832. self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
  833. kwargs_tensors = {}
  834. for k, x in kwargs.items():
  835. if isinstance(x, RawTensor):
  836. kwargs_tensors[k] = x
  837. if set(kwargs_tensors) != set(self._kwarg_bindings):
  838. too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
  839. too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
  840. if too_many:
  841. raise TraceMismatchError(
  842. "keyword arguments found to be tensor this time "
  843. "but were non-tensor previously: %s" % " ".join(too_many)
  844. )
  845. if too_few:
  846. raise TraceMismatchError(
  847. "keyword arguments found to be non-tensor this time "
  848. "but were tensor previously: %s" % " ".join(too_few)
  849. )
  850. for k, h in self._kwarg_bindings.items():
  851. x = kwargs_tensors[k]
  852. info = self._tinfo[h]
  853. if x.dtype != info.dtype:
  854. raise TypeError("kwargs[%s].dtype different from last time" % k)
  855. if x.device != info.device:
  856. raise TypeError("kwargs[%s].device different from last time" % k)
  857. info.data_setter.set_value(x._dev_tensor())
  858. self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
  859. def _process_outputs(self, outputs):
  860. output_names = None
  861. if isinstance(outputs, collections.abc.Mapping):
  862. output_names, outputs = zip(*sorted(outputs.items()))
  863. elif not isinstance(outputs, collections.abc.Sequence):
  864. outputs = (outputs,)
  865. if not self._untraced:
  866. if output_names != self._output_names:
  867. too_many = set(output_names) - set(self._output_names)
  868. too_few = set(self._output_names) - set(output_names)
  869. if too_many:
  870. raise TraceMismatchError(
  871. "output has more keys than last time: %s" % " ".join(too_many)
  872. )
  873. if too_few:
  874. raise TraceMismatchError(
  875. "output has less keys than last time: %s" % " ".join(too_few)
  876. )
  877. if len(outputs) != len(self._output_bindings):
  878. raise TraceMismatchError("output size differs from last time")
  879. else:
  880. self._output_names = output_names
  881. self._output_bindings = []
  882. for i, x in enumerate(outputs):
  883. if not isinstance(x, RawTensor):
  884. raise TypeError("every item of return value should be tensor")
  885. if self._untraced:
  886. h = x._mixin_handle
  887. if h < 0:
  888. raise RuntimeError("output is not computed from inputs")
  889. self._output_bindings.append(h)
  890. else:
  891. h = x._mixin_handle
  892. if h not in self._output_handles:
  893. raise RuntimeError("output is not computed from inputs")
  894. if h != self._output_bindings[i]:
  895. raise TraceMismatchError(
  896. "retval[%s] is a different tensor than last time"
  897. % (output_names and output_names[i] or i)
  898. )
  899. def get_profile(self):
  900. """
  901. Get profiling result for compiled trace.
  902. :return: a json compatible object.
  903. """
  904. if not self._profiler:
  905. raise RuntimeError("trace is not set with profiling=True")
  906. return json.loads(self._profiler.get())
  907. def trace(self, *args, **kwargs):
  908. raise NotImplementedError(
  909. "trace is deemed unbeneficial with the new "
  910. "tracing mechanism. You should alwasy use __call__."
  911. )
  912. class CompiledTensorProxy:
  913. """
  914. Duck-typed RawTensor
  915. """
  916. def __init__(self, handle):
  917. self.__handle = handle
  918. self._isscalar = False
  919. self.__info = active_trace._tinfo[handle]
  920. self.__shape = None
  921. self.__data = None
  922. self.__value = None
  923. @property
  924. def dtype(self):
  925. return self.__info.varnode.dtype
  926. @property
  927. def device(self):
  928. return self.__info.varnode.device
  929. @property
  930. def shape(self):
  931. if self._isscalar:
  932. return ()
  933. if self.__shape is None:
  934. if self.__info.shape_read:
  935. self.__shape = self.__info.shape_reader.get_value().shape
  936. elif self.__info.data_read:
  937. self.__shape = self._dev_tensor().shape
  938. else:
  939. # c++ will throw TraceReadError
  940. return None
  941. return self.__shape
  942. def numpy(self):
  943. if self.__value is None:
  944. if self.__info.value_read:
  945. self.__value = self.__info.value_reader.get_value()
  946. elif self.__info.data_read:
  947. self.__value = self._dev_tensor().numpy()
  948. else:
  949. # c++ will throw TraceReadError
  950. return None
  951. # c++ side will handle scalar case
  952. return self.__value
  953. def _dev_tensor(self):
  954. if self.__data is None:
  955. if not self.__info.data_read:
  956. # c++ will throw TraceReadError
  957. return None
  958. self.__data = self.__info.data_reader.get_value()
  959. return self.__data
  960. def __del__(self):
  961. if self.__info.shape_read and self.__shape is not None:
  962. self.__info.shape_reader.drop_value()
  963. if self.__info.value_read and self.__value is not None:
  964. self.__info.value_reader.drop_value()
  965. if self.__info.data_read and self.__data is not None:
  966. self.__info.data_reader.drop_value()
  967. def assign_raw_tensor(lhs, rhs):
  968. lhs.__init__(rhs)
  969. def apply_symbolic_mode(op: OpDef, *args: RawTensor):
  970. graph = active_trace._lazy_eval_graph
  971. ivars = []
  972. for x in args:
  973. var = getattr(x, "_varnode", None)
  974. if var:
  975. ivars.append(var)
  976. else:
  977. data_setter = G.InputNode(
  978. device=x.device,
  979. dtype=x.dtype,
  980. shape=x.numpy().shape or (1,),
  981. graph=graph,
  982. use_static_shape=True,
  983. )
  984. var = data_setter.outputs[0]
  985. ivars.append(var)
  986. data_setter.set_value(x._dev_tensor())
  987. require_links = type(op) in _io_op_types
  988. if require_links and active_trace._lazy_eval_links:
  989. assert len(ivars) > 0, "op should has at least one input"
  990. opnode = G.VirtualDepNode(
  991. [ivars[0], *active_trace._lazy_eval_links],
  992. str(active_trace._lazy_eval_links[0].device),
  993. )
  994. ivars[0] = opnode.outputs[0]
  995. active_trace._lazy_eval_links = (ivars[0],)
  996. ovars = G.apply_normal_varnode(op, *ivars)
  997. outputs = [RawTensor(o) for o in ovars]
  998. if require_links:
  999. active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)
  1000. return outputs
  1001. def apply_const_symbolic_mode(value, dtype, device, name):
  1002. graph = active_trace._lazy_eval_graph
  1003. # don't need to unset tracing
  1004. # because varnode construction will ignore tracing flag
  1005. ret = RawTensor(graph.make_const(value, dtype=dtype, device=device, name=name))
  1006. if np.array(value).ndim == 0:
  1007. setscalar(ret)
  1008. return (ret,)
  1009. def apply_compiled_mode(op: OpDef, *args: RawTensor):
  1010. if skip_tracing:
  1011. args = [
  1012. RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
  1013. for x in args
  1014. ]
  1015. unset_tracing()
  1016. ret = apply(op, *args)
  1017. set_tracing()
  1018. return ret
  1019. return active_trace._apply_op(op, args)
  1020. def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
  1021. if skip_tracing:
  1022. unset_tracing()
  1023. ret = RawTensor(value, dtype, device, False, name)
  1024. set_tracing()
  1025. return ret
  1026. return active_trace._apply_const(value, dtype, device)
  1027. def apply_with_tracing(op: OpDef, *args: RawTensor):
  1028. if active_trace._graph:
  1029. # if member _graph exits, then is_compiled
  1030. return apply_compiled_mode(op, *args)
  1031. if hasattr(op, "scope"):
  1032. op.scope = AutoNaming.get_scope()
  1033. if active_trace._symbolic:
  1034. outputs = apply_symbolic_mode(op, *args)
  1035. else:
  1036. unset_tracing()
  1037. outputs = apply(op, *args)
  1038. set_tracing()
  1039. active_trace._record_op(op, args, outputs)
  1040. return list(outputs)
  1041. def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
  1042. if active_trace._graph:
  1043. return apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name)
  1044. if active_trace._symbolic:
  1045. outputs = apply_const_symbolic_mode(value, dtype, device, name)
  1046. else:
  1047. unset_tracing()
  1048. outputs = RawTensor(value, dtype, device, False, name)
  1049. if np.array(value).ndim == 0:
  1050. setscalar(outputs)
  1051. outputs = (outputs,)
  1052. set_tracing()
  1053. active_trace._record_const(outputs)
  1054. return list(outputs)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台