
tracing.py 46 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import contextlib
import functools
import itertools
import json
import os
import pickle
from typing import Any

import numpy as np

from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    TensorWeakRef,
    apply,
    set_tracing,
    skip_tracing,
    unset_tracing,
)
from ..core._imperative_rt.ops import (
    AssertEqual,
    CollectiveComm,
    RemoteRecv,
    RemoteSend,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils.naming import AutoNaming
from ..utils.profiler import is_profiling
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig


def _input_node_use_static_shape():
    return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None

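# Illustrative, not part of the original file: since the check above only tests
# that the variable is set, exporting it with any value enables static-shape
# input nodes for compiled traces, e.g.
#
#     MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE=1 python infer.py
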
class TraceMismatchError(RuntimeError):
    pass


active_trace = None


def is_tracing():
    if active_trace is None:
        return False
    else:
        return not skip_tracing


@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        unset_tracing()
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
        set_tracing()

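# A minimal usage sketch (an illustration, not part of the original file;
# `f` and `x` are hypothetical). Code inside the block runs eagerly on every
# call instead of being replayed from the trace:
#
#     from megengine import Tensor
#     from megengine.jit import exclude_from_trace, trace
#
#     @trace
#     def f(x):
#         x = x + 1
#         with exclude_from_trace():
#             # data-dependent control flow is kept out of the trace
#             if x.numpy().mean() > 0:
#                 x = x + 1
#         return x
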
class TensorInfo:
    __slots__ = (
        # collected attributes
        "name",
        "external",
        "data_read",
        "shape_read",
        "value_read",
        "exported",
        "device",
        "dtype",
        "shape",
        "is_const",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.name = None
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None


_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}

class trace:
    """
    Wraps a callable and provides:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    :param function: the function to be traced.
    :param symbolic: whether to apply symbolic execution for tracing. Default: False
    :param capture_as_const: capture global vars or closures as const value. Default: False
    :param sublinear_memory_config: configuration for sublinear memory optimization.
        If not None, it enables sublinear memory optimization with given setting.
    :param dtr_config: configuration for DTR memory optimization.
        If not None, it enables DTR memory optimization with given setting.
    :param profiling: whether to profile compiled trace. Default: False
    :param opt_level: optimization level for compiling trace. Default: 2
    :param graph_opt_config: configuration for graph optimization. Default: None
    :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
    """
    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        function,
        symbolic=False,
        capture_as_const=False,
        sublinear_memory_config: SublinearMemoryConfig = None,
        dtr_config: DTRConfig = None,
        profiling: bool = False,
        opt_level: int = 2,
        graph_opt_config: GraphOptimizationConfig = None,
        symbolic_shape: bool = True,
    ):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._sublinear_memory_config = sublinear_memory_config
        self._dtr_config = dtr_config
        self._profiling = profiling
        self._profiler = None
        self._profiler2 = None
        self._graph_opt_level = opt_level
        self._graph_opt_config = graph_opt_config
        self._symbolic_shape = symbolic_shape
        self._output_handles = set()

        self._reset()

    def _reset(self):
        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = set()
        self._lazy_eval_links = None
        self._active_tensors = set()
        self._tensor_remaps = None
        self._inputs_to_restore = None
        self._arg_bindings = None
        self._kwarg_bindings = None
        self._output_bindings = None
        self._output_names = None

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info
    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if (isinstance(op_, str) and op_ == "Const") or (op != op_):
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        # check all inputs of current op
        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x._compiled_info is not None
                    and not self._tinfo[x._mixin_handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x._compiled_info is not None:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        if not np.array_equal(x.numpy(), info.bound_data.numpy()):
                            raise TraceMismatchError(
                                "const capture violated: got "
                                "a different tensor this time"
                            )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x._mixin_handle == -1:
                    if x._handle not in self._tensor_remaps:
                        raise TraceMismatchError(
                            "unexpected capture: trying to use an external tensor as "
                            "input, but that input was an internal tensor last time"
                        )
                    else:
                        x._mixin_handle = self._tensor_remaps[
                            x._handle
                        ]._CompiledTensorProxy__handle
                if x._mixin_handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to a data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = []
        for h in ohandles:
            info = self._tinfo[h]

            # generate output tensor and create compiled info
            y = RawTensor(info.varnode)
            y._compiled_info = CompiledTensorProxy(h)
            y._mixin_handle = h
            outputs += [y]
            self._active_tensors.add(TensorWeakRef(y))
        self._output_handles.update(ohandles)
        return outputs
    def _apply_const(self, value, dtype, device):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        # Const op is represented by a str
        assert isinstance(op_, str) and op_ == "Const"

        expected = self._tinfo[ohandles[0]].bound_data.numpy()
        shape = value.shape
        if shape != expected.shape or dtype != expected.dtype:
            eq = False
        elif shape == ():
            eq = expected.item() == value.item()
        elif shape == (1,):
            eq = expected[0] == value[0]
        else:
            eq = np.all(value == expected)
        if not eq:
            raise TraceMismatchError(
                "const tensor violated: got a different tensor this time"
            )

        self._pc += 1
        (h,) = ohandles
        outputs = [self._tinfo[h].bound_data]
        return outputs
    # run in first step, record information for trace
    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_mixin_handle", -1)
                if h >= 0:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_mixin_handle", -1)
            if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                name = AutoNaming.gen_name(x)
                info.name = name
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.shape
                if self._capture_as_const:
                    info.bound_data = RawTensor(
                        x.numpy(), x.dtype, x.device, False, name
                    )

            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            x._mixin_handle = h
            x._recording = True
            x._trace_mixin_info = info
            self._active_tensors.add(TensorWeakRef(x))
            if self._symbolic:
                self._lazy_eval_tensors.add(TensorWeakRef(x))

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))

    def _record_const(self, outputs):
        if skip_tracing:
            (x,) = outputs
            h = getattr(x, "_mixin_handle", -1)
            if h >= 0:
                self._tinfo[h].data_read = True
            return

        (x,) = outputs
        h, info = self._new_handle()
        ohandles = [h]
        info.external = True
        info.device = x.device
        info.dtype = x.dtype
        info.shape = x.shape
        info.bound_data = x
        info.is_const = True
        x._mixin_handle = h
        x._recording = True
        x._trace_mixin_info = info
        if self._symbolic:
            self._lazy_eval_tensors.add(TensorWeakRef(x))
        self._seq.append(("Const", tuple(), tuple(ohandles)))
    def _set_active(self, active: bool):
        global active_trace
        if active:
            if active_trace:
                raise NotImplementedError("sorry, not implemented: nested trace")
            active_trace = self
        else:
            assert active_trace is self
            active_trace = None

    def _init_trace(self, symbolic: bool):
        if symbolic:
            self._lazy_eval_graph = G.Graph()
            self._apply_graph_options(self._lazy_eval_graph)
            self._lazy_eval_links = ()

    def _take_escaped_tensors(self):
        escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
        self._active_tensors.clear()
        return escaped_tensors

    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
        lazy_eval_tensors = [x() for x in lazy_eval_tensors]
        lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None]
        readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors]
        self._apply_graph_options(lazy_eval_graph)
        lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
        lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
        lazy_eval_graph.compile(*lazy_eval_links, *readers)
        self._execute_graph(lazy_eval_graph)
        lazy_eval_graph.wait()
        for r, x in zip(readers, lazy_eval_tensors):
            # get values from lazy_eval_graph and assign to lazy_eval tensor
            x._handle = RawTensor(r.op.get_value())._handle
            x._reset_varnode()
    @contextlib.contextmanager
    def _setup(self):
        interrupted = False

        def do_enter():
            set_tracing()
            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
            self._set_active(True)
            if self._untraced:
                self._init_trace(self._symbolic)
            else:
                if self._graph is None:
                    self._compile()
                self._execute_graph(self._graph)

        def do_finalize():
            escaped_tensors = self._take_escaped_tensors()
            if self._untraced:
                for x in escaped_tensors:
                    if x():
                        info = self._tinfo[x()._mixin_handle]
                        info.data_read = True
                        x()._mixin_handle = -1
                        x()._recording = False
                if self._inputs_to_restore:
                    for x in self._inputs_to_restore:
                        x._mixin_handle = -1
                        x._recording = False
                if self._symbolic and (
                    self._lazy_eval_tensors or self._lazy_eval_links
                ):
                    # eval lazy eval tensors
                    self._lazy_eval(
                        self._lazy_eval_graph,
                        self._lazy_eval_tensors,
                        self._lazy_eval_links,
                    )
                    self._lazy_eval_graph = None
                    self._lazy_eval_tensors = None
                    self._lazy_eval_links = None
                self._untraced = False
            else:
                # compiled_tensor leaks
                if self._pc == len(self._seq):
                    for x in escaped_tensors:
                        try:
                            assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
                        except RuntimeError:
                            # TraceMismatchError thrown in do_exit
                            pass
                    self._graph.wait()
                    self._reset_exec_env()

            # reset status
            self._pc = 0
            self._tensor_remaps = None
            self._set_active(False)
            set_symbolic_shape(self._save_symbolic_shape)
            unset_tracing()

        def do_exit():
            unset_tracing()
            if not self._untraced and self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            if not self._symbolic or not self._untraced:
                # reset output tensors
                for x in self._active_tensors.copy():
                    strong_x = x()
                    if strong_x is not None:
                        strong_x._dev_tensor()
                        strong_x._reset_varnode()
                        strong_x._mixin_handle = -1
                        strong_x._recording = False
                        strong_x._trace_mixin_info = None

        try:
            do_enter()
            yield
            do_exit()
        except:
            interrupted = True
            raise
        finally:
            do_finalize()
            if interrupted:
                self._reset()
    def _begin_excluded_region(self):
        if self._capture_as_const:
            raise RuntimeError(
                "exclude_from_trace cannot be used with capture_as_const"
            )
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                strong_x = x()
                if strong_x:
                    info = self._tinfo[strong_x._mixin_handle]
                    info.exported = True
                    info.data_read = True
        else:
            for x in self._active_tensors:
                strong_x = x()
                if strong_x:
                    strong_x._dev_tensor()
    def _apply_graph_options(self, graph):
        graph.options.no_force_inplace = True
        graph.options.seq_opt.enable_seq_comp_node_opt = False
        graph.options.graph_opt_level = self._graph_opt_level
        if self._dtr_config is not None:
            graph.options.enable_dtr_memory_opt = True
            graph.options.dtr_config.eviction_threshold = (
                self._dtr_config.eviction_threshold
            )
            graph.options.dtr_config.evictee_minimum_size = (
                self._dtr_config.evictee_minimum_size
            )
        # graph optimization
        if self._graph_opt_config is not None:
            mapping = {None: 0, False: 1, True: 2}
            jit_config = graph.options.graph_opt.jit_config
            jit_config.fuse_dimshuffle = mapping[
                self._graph_opt_config.jit_fuse_dimshuffle
            ]
            jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
            sublinear_config = graph.options.sublinear_mem_config
            sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
            sublinear_config.genetic_nr_iter = (
                self._sublinear_memory_config.genetic_nr_iter
            )
            sublinear_config.genetic_pool_size = (
                self._sublinear_memory_config.genetic_pool_size
            )
            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
            sublinear_config.num_worker = self._sublinear_memory_config.num_worker
        # profile
        if self._profiling:
            self._profiler = GraphProfiler(graph)
        self._profiler2 = None
        if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
            graph.options.var_sanity_check_first_run = False

    def _execute_graph(self, graph: G.Graph, *args):
        if is_profiling() and (self._profiler2 is None):
            self._profiler2 = GraphProfiler2(graph)
        elif not is_profiling() and (self._profiler2 is not None):
            self._profiler2 = None
        graph.execute(*args)
    def _compile(self):
        graph = self._graph = G.Graph()
        graph.options.async_exec_level = 0b100
        self._apply_graph_options(graph)
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        in_out_links = ()
        io_links = ()
        readers = []

        if self._capture_as_const:
            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                info = self._tinfo[h]
                opnode = info.data_setter = G.InputNode(
                    device=info.device,
                    dtype=info.dtype,
                    shape=info.shape or (1,),
                    graph=graph,
                    use_static_shape=_input_node_use_static_shape(),
                )
                need_reset_nodes.append(opnode)
                info.varnode = opnode.outputs[0]
                in_out_links += opnode.outputs[1:]

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    assert info.bound_data
                    info.varnode = graph.make_const(
                        info.bound_data.numpy(),
                        info.bound_data.dtype,
                        info.bound_data.device,
                    )
                continue

            require_links = type(op) in _io_op_types
            ivars = []
            for i, h in enumerate(ihandles):
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        if getattr(info, "is_const", False):
                            info.varnode = graph.make_const(
                                info.bound_data.numpy(),
                                info.bound_data.dtype,
                                info.bound_data.device,
                            )
                        else:
                            info.varnode = graph.make_const(
                                info.bound_data._dev_tensor()
                                # info.bound_data.numpy()
                            )
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *in_out_links,
                            device=info.device,
                            dtype=info.dtype,
                            shape=info.shape or (1,),
                            graph=graph,
                            use_static_shape=_input_node_use_static_shape(),
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *in_out_links = opnode.outputs

                if require_links and i == 0 and len(io_links) > 0:
                    opnode = G.VirtualDepNode(
                        [info.varnode, *io_links], str(io_links[0].device)
                    )
                    info.varnode = opnode.outputs[0]
                    io_links = (info.varnode,)

                ivars.append(info.varnode)

            ovars = G.apply_normal_varnode(op, *ivars)

            if require_links and len(ovars) > 0:
                io_links = (ovars[0],)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal in_out_links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    in_out_links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                    add_reader(opnode)

        graph.options.graph_opt_level = self._graph_opt_level
        graph._set_priority_to_id([*readers, *in_out_links, *io_links])
        graph.compile(*readers, *in_out_links, *io_links)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()
    def __call__(self, *args, **kwargs):
        with self._setup():
            if self._capture_as_const:
                self._process_inputs(*args, **kwargs)
            outputs = self.__wrapped__(*args, **kwargs)
            if self._capture_as_const:
                self._process_outputs(outputs)
            return outputs
    def dump(
        self,
        file,
        *,
        arg_names=None,
        output_names=None,
        append=False,
        keep_var_name: int = 1,
        keep_opr_name: bool = False,
        keep_param_name: bool = False,
        keep_opr_priority: bool = False,
        strip_info_file=None,
        append_json=False,
        optimize_for_inference=True,
        user_info: Any = None,
        enable_metadata: bool = True,
        **kwargs
    ):
        r"""
        Serializes trace to file system.

        :param file: output file, could be file object or filename.
        :param arg_names: names of the input tensors in the traced function.
        :param output_names: names of the output tensors in the traced function,
            use the default name if not specified.
        :param append: whether output is appended to ``file``.
            Only works when ``file`` is str.
        :param keep_var_name: level for keeping variable names:

            * 0: none of the names are kept
            * 1: (default) keep names of output vars
            * 2: keep names of all (output and internal) vars

        :param keep_opr_name: whether to keep operator names.
        :param keep_param_name: whether to keep param names, so param values can be
            easily manipulated after loading model.
        :param keep_opr_priority: whether to keep priority setting for operators.
        :param strip_info_file: a string for path or a file handler. If not None,
            the dump information for code strip will be written to ``strip_info_file``.
        :param append_json: only checked when ``strip_info_file`` is not None. If set
            to True, the information for code strip will be appended to
            ``strip_info_file``; if set to False, ``strip_info_file`` will be
            overwritten.
        :param optimize_for_inference: enable optimizations;
            all optimize options will be skipped if this is False. Default: True
        :param user_info: any type object, which will be pickled to bytes.
        :param enable_metadata: whether to save metadata into output file.

        :Keyword Arguments:

            * enable_io16xc32 --
              whether to use float16 for I/O between oprs and use
              float32 as internal computation precision. Note the output var would be
              changed to float16.
            * enable_ioc16 --
              whether to use float16 for both I/O and computation
              precision.
            * enable_hwcd4 --
              whether to use NHWCD4 data layout. This is faster on some
              OpenCL backends.
            * enable_nchw88 --
              whether to use NCHW88 data layout, currently
              used in X86 AVX backend.
            * enable_nchw44 --
              whether to use NCHW44 data layout, currently
              used in arm backend.
            * enable_nchw44_dot --
              whether to use NCHW44_dot data layout, currently
              used in armv8.2+dotprod backend.
            * enable_nchw4 --
              whether to use NCHW4 data layout, currently
              used in nvidia backend (based on cudnn).
            * enable_nchw32 --
              whether to use NCHW32 data layout, currently
              used in nvidia backend with tensorcore (based on cudnn).
            * enable_chwn4 --
              whether to use CHWN4 data layout, currently
              used in nvidia backend with tensorcore.
            * enable_nchw64 --
              whether to use NCHW64 data layout, used for fast int4
              support on Nvidia GPU.
            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
              into one opr.
            * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
              input for inference on nvidia backend (this optimization pass will
              result in mismatch of the precision of output of training and
              inference).
        """
        if not self._capture_as_const:
            raise ValueError(
                "you must specify capture_as_const=True at __init__ to use dump"
            )
        if self._untraced:
            raise RuntimeError("should run at least once before calling dump")
        if self._output_names and output_names:
            raise TypeError(
                "cannot specify output_names when output is already in dict format"
            )
        if output_names and not isinstance(output_names, collections.abc.Sequence):
            output_names = (output_names,)
        if output_names and len(output_names) != len(self._output_bindings):
            raise ValueError(
                "wrong number of output_names, should be {} values".format(
                    len(self._output_bindings)
                )
            )
        without_arg_names = arg_names is None
        if without_arg_names:
            arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
        if arg_names and not isinstance(arg_names, collections.abc.Sequence):
            arg_names = (arg_names,)
        if arg_names and len(arg_names) != len(self._arg_bindings):
            raise ValueError(
                "wrong number of arg_names, should be {} values".format(
                    len(self._arg_bindings)
                )
            )
        output_names = output_names or self._output_names

        def dumped_device(info):
            device_name = info.device.logical_name
            if device_name[:3] in ("cpu", "gpu", "xpu"):
                return as_device("xpux")
            return info.device

        h2v = {}
        graph = G.Graph()

        # apply graph_opt_level in dump
        if self._graph_opt_level is not None:
            graph.options.graph_opt_level = self._graph_opt_level
        for i, h in enumerate(self._arg_bindings):
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=dumped_device(info),
                shape=info.shape or (1,),
                name=info.name if without_arg_names and info.name else arg_names[i],
            )
        for k, h in self._kwarg_bindings.items():
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=dumped_device(info),
                shape=info.shape or (1,),
                name=k,
            )

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(
                        info.bound_data.numpy(),
                        dtype=info.dtype,
                        device=dumped_device(info),
                        name=info.name,
                    )
                continue
            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(
                        info.bound_data.numpy(),
                        dtype=info.dtype,
                        device=dumped_device(info),
                        name=info.name,
                    )
                ivars.append(h2v[h])
            if isinstance(op, BatchNorm):
                assert (
                    op.fwd_mode == BatchNorm.FwdMode.INFERENCE
                ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
            ovars = G.apply_normal_varnode(op, *ivars)

            AutoNaming.record_opnode(ovars[0].op)

            assert len(ovars) == len(ohandles)
            h2v.update(zip(ohandles, ovars))

            for i in ohandles:
                name = AutoNaming.get_var_name(i)
                if name is not None:
                    h2v[i].name = name

        AutoNaming.remove_duplicate_names()

        dest_vars = []
        for i, h in enumerate(self._output_bindings):
            v = h2v[h]
            if output_names:
                v.name = output_names[i]
            dest_vars.append(v)

        if optimize_for_inference:
            dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)

        metadata = SerializationMetadata()
        if enable_metadata:
            metadata.user_info = pickle.dumps(user_info)
            metadata.is_valid = True
            metadata.graph_modified = False
            if optimize_for_inference:
                metadata.optimize_options = optimize_options

        if isinstance(file, str):
            permission = "ab" if append else "wb"
            file = open(file, permission)

        if keep_opr_priority:
            graph._set_priority_to_id(dest_vars)

        dump_content, dump_info = G.dump_graph(
            dest_vars,
            keep_var_name=keep_var_name,
            keep_opr_name=keep_opr_name,
            keep_param_name=keep_param_name,
            keep_opr_priority=keep_opr_priority,
            strip_info_file=strip_info_file,
            append_json=append_json,
            metadata=metadata,
        )
        file.write(dump_content)
        return dump_info
    def _process_inputs(self, *args, **kwargs):
        if self._untraced:
            self._inputs_to_restore = []

            def record_input(x):
                if x is None:
                    return
                h, info = self._new_handle()
                info.external = False
                info.name = x.c_name
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.numpy().shape
                x._mixin_handle = h
                x._recording = True
                x._trace_mixin_info = info
                self._inputs_to_restore.append(x)
                return h

            self._arg_bindings = []
            for i, x in enumerate(args):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                self._arg_bindings.append(record_input(x))

            self._kwarg_bindings = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    self._kwarg_bindings[k] = record_input(x)
        else:
            if len(args) != len(self._arg_bindings):
                raise TraceMismatchError("positional argument length mismatch")

            self._tensor_remaps = {}

            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("args[%d].dtype different from last time" % i)
                if x.device != info.device:
                    raise TypeError("args[%d].device different from last time" % i)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

            kwargs_tensors = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    kwargs_tensors[k] = x
            if set(kwargs_tensors) != set(self._kwarg_bindings):
                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                if too_many:
                    raise TraceMismatchError(
                        "keyword arguments found to be tensor this time "
                        "but were non-tensor previously: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "keyword arguments found to be non-tensor this time "
                        "but were tensor previously: %s" % " ".join(too_few)
                    )
            for k, h in self._kwarg_bindings.items():
                x = kwargs_tensors[k]
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("kwargs[%s].dtype different from last time" % k)
                if x.device != info.device:
                    raise TypeError("kwargs[%s].device different from last time" % k)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
    def _process_outputs(self, outputs):
        output_names = None
        if isinstance(outputs, collections.abc.Mapping):
            output_names, outputs = zip(*sorted(outputs.items()))
        elif not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)

        if not self._untraced:
            if output_names != self._output_names:
                too_many = set(output_names) - set(self._output_names)
                too_few = set(self._output_names) - set(output_names)
                if too_many:
                    raise TraceMismatchError(
                        "output has more keys than last time: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "output has fewer keys than last time: %s" % " ".join(too_few)
                    )
            if len(outputs) != len(self._output_bindings):
                raise TraceMismatchError("output size differs from last time")
        else:
            self._output_names = output_names
            self._output_bindings = []

        for i, x in enumerate(outputs):
            if not isinstance(x, RawTensor):
                raise TypeError("every item of return value should be tensor")
            if self._untraced:
                h = x._mixin_handle
                if h < 0:
                    raise RuntimeError("output is not computed from inputs")
                self._output_bindings.append(h)
            else:
                h = x._mixin_handle
                if h not in self._output_handles:
                    raise RuntimeError("output is not computed from inputs")
                if h != self._output_bindings[i]:
                    raise TraceMismatchError(
                        "retval[%s] is a different tensor than last time"
                        % (output_names and output_names[i] or i)
                    )
    def get_profile(self):
        """
        Get profiling result for compiled trace.

        :return: a json compatible object.
        """
        if not self._profiler:
            raise RuntimeError("trace is not set with profiling=True")
        return json.loads(self._profiler.get())
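
    # Illustrative profiling flow (a sketch, not part of the original file;
    # `model` and `data` are hypothetical):
    #
    #     f = trace(model, profiling=True)
    #     f(data)  # first call records the trace
    #     f(data)  # second call compiles and runs the graph
    #     with open("profile.json", "w") as fout:
    #         json.dump(f.get_profile(), fout)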
    def trace(self, *args, **kwargs):
        raise NotImplementedError(
            "trace is deemed unbeneficial with the new "
            "tracing mechanism. You should always use __call__."
        )
class CompiledTensorProxy:
    """
    Duck-typed RawTensor
    """

    def __init__(self, handle):
        self.__handle = handle
        self._isscalar = False
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self._isscalar:
            return ()
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                # c++ will throw TraceReadError
                return None
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                # c++ will throw TraceReadError
                return None
        # c++ side will handle scalar case
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                # c++ will throw TraceReadError
                return None
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()
def assign_raw_tensor(lhs, rhs):
    lhs.__init__(rhs)


def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = []
    for x in args:
        var = getattr(x, "_varnode", None)
        if var:
            ivars.append(var)
        else:
            data_setter = G.InputNode(
                device=x.device,
                dtype=x.dtype,
                shape=x.numpy().shape or (1,),
                graph=graph,
                use_static_shape=True,
            )
            var = data_setter.outputs[0]
            ivars.append(var)
            data_setter.set_value(x._dev_tensor())

    require_links = type(op) in _io_op_types

    if require_links and active_trace._lazy_eval_links:
        assert len(ivars) > 0, "op should have at least one input"
        opnode = G.VirtualDepNode(
            [ivars[0], *active_trace._lazy_eval_links],
            str(active_trace._lazy_eval_links[0].device),
        )
        ivars[0] = opnode.outputs[0]
        active_trace._lazy_eval_links = (ivars[0],)

    ovars = G.apply_normal_varnode(op, *ivars)
    outputs = [RawTensor(o) for o in ovars]

    if require_links:
        active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)

    return outputs


def apply_const_symbolic_mode(value, dtype, device, name):
    graph = active_trace._lazy_eval_graph
    # don't need to unset tracing
    # because varnode construction will ignore tracing flag
    ret = RawTensor(graph.make_const(value, dtype=dtype, device=device, name=name))
    if np.array(value).ndim == 0:
        setscalar(ret)
    return (ret,)
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        unset_tracing()
        ret = apply(op, *args)
        set_tracing()
        return ret
    return active_trace._apply_op(op, args)


def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
    if skip_tracing:
        unset_tracing()
        ret = RawTensor(value, dtype, device, False, name)
        set_tracing()
        return ret
    return active_trace._apply_const(value, dtype, device)


def apply_with_tracing(op: OpDef, *args: RawTensor):
    if active_trace._graph:
        # if member _graph exists, then is_compiled
        return apply_compiled_mode(op, *args)
    if hasattr(op, "scope"):
        op.scope = AutoNaming.get_scope()
    if active_trace._symbolic:
        outputs = apply_symbolic_mode(op, *args)
    else:
        unset_tracing()
        outputs = apply(op, *args)
        set_tracing()

    active_trace._record_op(op, args, outputs)
    return list(outputs)


def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
    if active_trace._graph:
        return apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name)
    if active_trace._symbolic:
        outputs = apply_const_symbolic_mode(value, dtype, device, name)
    else:
        unset_tracing()
        outputs = RawTensor(value, dtype, device, False, name)
        if np.array(value).ndim == 0:
            setscalar(outputs)
        outputs = (outputs,)
        set_tracing()
    active_trace._record_const(outputs)
    return list(outputs)
