You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tracing.py 40 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import contextlib
  11. import functools
  12. import itertools
  13. import json
  14. import os
  15. import typing
  16. import warnings
  17. import weakref
  18. import numpy as np
  19. from ..core._imperative_rt import GraphProfiler, common, put
  20. from ..core._imperative_rt.core2 import Tensor as RawTensor
  21. from ..core._imperative_rt.core2 import (
  22. TensorWeakRef,
  23. apply,
  24. call_level,
  25. set_compiled,
  26. set_symbolic,
  27. set_tracing,
  28. skip_tracing,
  29. unset_compiled,
  30. unset_symbolic,
  31. unset_tracing,
  32. )
  33. from ..core._imperative_rt.ops import (
  34. CollectiveComm,
  35. GaussianRNG,
  36. RemoteRecv,
  37. RemoteSend,
  38. UniformRNG,
  39. )
  40. from ..core._trace_option import set_symbolic_shape
  41. from ..core._wrap import device as as_device
  42. from ..core.ops.builtin import OpDef
  43. from ..core.ops.special import Const
  44. from ..core.tensor import megbrain_graph as G
  45. from .sublinear_memory_config import SublinearMemoryConfig
  46. def _input_node_use_static_shape():
  47. return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
  48. class TraceMismatchError(RuntimeError):
  49. pass
  50. active_trace = None
  51. def is_tracing():
  52. if active_trace is None:
  53. return False
  54. else:
  55. return not skip_tracing
  56. @contextlib.contextmanager
  57. def exclude_from_trace():
  58. global skip_tracing
  59. if skip_tracing:
  60. yield
  61. return
  62. try:
  63. skip_tracing = True
  64. unset_tracing()
  65. if active_trace is not None:
  66. active_trace._begin_excluded_region()
  67. yield
  68. finally:
  69. skip_tracing = False
  70. set_tracing()
  71. class TensorInfo:
  72. __slots__ = (
  73. # collected attributes
  74. "external",
  75. "exported",
  76. "device",
  77. "dtype",
  78. "shape",
  79. "is_const",
  80. "bound_data",
  81. # resources for execution
  82. "varnode",
  83. "data_setter",
  84. "shape_reader",
  85. "value_reader",
  86. "data_reader",
  87. )
  88. def __init__(self):
  89. self.exported = None
  90. self.bound_data = None
  91. self.data_setter = None
  92. self.shape_reader = None
  93. self.value_reader = None
  94. self.data_reader = None
  95. _io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
  96. class trace:
  97. """
  98. Wraps a callable and provide:
  99. * tracing via :meth:`.trace` and :meth:`.dump`
  100. * accelerated evalutaion via :meth:`.__call__`
  101. :param function: the function will be traced.
  102. :param symbolic: whether to apply symbolic execution for tracing. Default: False
  103. :param capture_as_const: capture global vars or closures as const value. Default: False
  104. :param sublinear_memory_config: configuration for sublinear memory optimization.
  105. If not None, it enables sublinear memory optimization with given setting.
  106. :param profiling: whether to profile compiled trace. Default: False
  107. :param opt_level: optimization level for compiling trace.
  108. :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
  109. """
  110. def __new__(cls, *args, **kwargs):
  111. if not args:
  112. return functools.partial(cls, **kwargs)
  113. return super().__new__(cls)
  114. def __init__(
  115. self,
  116. function,
  117. symbolic=False,
  118. capture_as_const=False,
  119. sublinear_memory_config: SublinearMemoryConfig = None,
  120. profiling: bool = False,
  121. opt_level: int = None,
  122. symbolic_shape: bool = True,
  123. ):
  124. self.__wrapped__ = function
  125. self._symbolic = symbolic
  126. self._capture_as_const = capture_as_const
  127. self._sublinear_memory_config = sublinear_memory_config
  128. self._profiling = profiling
  129. self._profiler = None
  130. self._graph_opt_level = opt_level
  131. self._symbolic_shape = symbolic_shape
  132. self._handle2tensors = {}
  133. self._handle2compiledtensors = {}
  134. self._reset()
  135. def _reset(self):
  136. self._untraced = True
  137. self._tinfo = [] # handle -> TensorInfo
  138. self._seq = []
  139. self._pc = 0
  140. self._graph = None
  141. self._need_reset_nodes = None
  142. self._lazy_eval_graph = None
  143. self._lazy_eval_tensors = set()
  144. self._lazy_eval_links = None
  145. self._active_tensors = set()
  146. self._tensor_remaps = None
  147. self._inputs_to_restore = None
  148. self._arg_bindings = None
  149. self._kwarg_bindings = None
  150. self._output_bindings = None
  151. self._output_names = None
  152. def _new_handle(self):
  153. handle = len(self._tinfo)
  154. info = TensorInfo()
  155. self._tinfo.append(info)
  156. return handle, info
  157. def _apply_op(self, op, args):
  158. assert not self._untraced
  159. # check against trace
  160. if self._pc >= len(self._seq):
  161. raise TraceMismatchError("trace should end here, but more op observed")
  162. record = self._seq[self._pc]
  163. op_, ihandles, ohandles = record
  164. if op != op_:
  165. raise TraceMismatchError("op different from last time")
  166. if len(ihandles) != len(args):
  167. raise TraceMismatchError("op input size different from last time")
  168. for h, x in zip(ihandles, args):
  169. info = self._tinfo[h]
  170. if info.external:
  171. if (
  172. x.__class__ is CompiledTensorProxy
  173. and not self._tinfo[x._CompiledTensorProxy__handle].exported
  174. ):
  175. raise TraceMismatchError(
  176. "failed to capture: input was an external tensor "
  177. "last time, got an internal tensor this time"
  178. )
  179. if info.bound_data:
  180. if x.__class__ is CompiledTensorProxy:
  181. raise TraceMismatchError(
  182. "const capture violated: was an external tensor "
  183. "last time, got an internal tensor this time"
  184. )
  185. if x._handle != info.bound_data._handle:
  186. if not np.array_equal(x.numpy(), info.bound_data.numpy()):
  187. raise TraceMismatchError(
  188. "const capture violated: got "
  189. "a different tensor this time"
  190. )
  191. else:
  192. if info.dtype != x.dtype:
  193. raise TraceMismatchError(
  194. "failed to capture: different dtype from last time"
  195. )
  196. if info.device != x.device:
  197. raise TraceMismatchError(
  198. "failed to capture: different device from last time"
  199. )
  200. info.data_setter.set_value(x._dev_tensor())
  201. else:
  202. pass
  203. # if x.__class__ is not CompiledTensorProxy:
  204. # if x not in self._tensor_remaps:
  205. # raise TraceMismatchError(
  206. # "unexpected capture: trying to use an external tensor as "
  207. # "input, but that input was an internal tensor last time"
  208. # )
  209. # else:
  210. # x = self._tensor_remaps[x]
  211. # if x._CompiledTensorProxy__handle != h:
  212. # raise TraceMismatchError(
  213. # "mis-wiring: input edge to an data flow "
  214. # "graph node is different from last time"
  215. # )
  216. self._pc += 1
  217. for h in ohandles:
  218. t = CompiledTensorProxy(h)
  219. t._dev_tensor()
  220. self._handle2compiledtensors[h] = t
  221. outputs = [self._handle2tensors[h] for h in ohandles]
  222. self._active_tensors.update([TensorWeakRef(o) for o in outputs])
  223. return outputs
  224. def _apply_const(self, value, dtype, device):
  225. assert not self._untraced
  226. # check against trace
  227. if self._pc >= len(self._seq):
  228. raise TraceMismatchError("trace should end here, but more op observed")
  229. record = self._seq[self._pc]
  230. op_, ihandles, ohandles = record
  231. assert isinstance(op_, str) and op_ == "Const"
  232. # TODO : assert on const value
  233. # eq = value == self._tinfo[ohandles[0]].bound_data.numpy()
  234. # if not isinstance(eq, bool):
  235. # eq = all(eq)
  236. # if not eq:
  237. # raise TraceMismatchError(
  238. # "const tensor violated: got a different tensor this time"
  239. # )
  240. self._pc += 1
  241. (h,) = ohandles
  242. outputs = [self._tinfo[h].bound_data]
  243. return outputs
  244. def _record_op(self, op, inputs, outputs):
  245. if skip_tracing:
  246. for x in inputs:
  247. h = getattr(x, "mixin_handle", -1)
  248. if h >= 0:
  249. x.data_read = True
  250. return
  251. ihandles = []
  252. for x in inputs:
  253. h = getattr(x, "mixin_handle", -1)
  254. if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
  255. h, info = self._new_handle()
  256. info.external = True
  257. info.device = x.device
  258. info.dtype = x.dtype
  259. info.shape = x.shape
  260. if self._capture_as_const:
  261. info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False)
  262. ihandles.append(h)
  263. ohandles = []
  264. for x in outputs:
  265. h, info = self._new_handle()
  266. ohandles.append(h)
  267. info.external = False
  268. x.mixin_handle = h
  269. self._handle2tensors[h] = x
  270. self._seq.append((op, tuple(ihandles), tuple(ohandles)))
  271. self._active_tensors.update([TensorWeakRef(o) for o in outputs])
  272. def _record_const(self, outputs):
  273. if skip_tracing:
  274. (x,) = outputs
  275. h = getattr(x, "mixin_handle", -1)
  276. if h >= 0:
  277. x.data_read = True
  278. return
  279. (x,) = outputs
  280. h, info = self._new_handle()
  281. ohandles = [h]
  282. info.external = True
  283. info.device = x.device
  284. info.dtype = x.dtype
  285. info.shape = x.shape
  286. info.bound_data = x
  287. info.is_const = True
  288. x.mixin_handle = h
  289. self._handle2tensors[h] = x
  290. self._seq.append(("Const", tuple(), tuple(ohandles)))
  291. def _set_active(self, active: bool):
  292. global active_trace
  293. if active:
  294. if active_trace:
  295. raise NotImplementedError("sorry, not implemented: nested trace")
  296. active_trace = self
  297. else:
  298. assert active_trace is self
  299. active_trace = None
  300. def _init_trace(self, symbolic: bool):
  301. if symbolic:
  302. set_symbolic()
  303. self._lazy_eval_graph = G.Graph()
  304. self._apply_graph_options(self._lazy_eval_graph)
  305. self._lazy_eval_links = ()
  306. def _take_escaped_tensors(self):
  307. escaped_tensors = tuple(self._active_tensors)
  308. self._active_tensors.clear()
  309. return escaped_tensors
  310. def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
  311. readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
  312. self._apply_graph_options(lazy_eval_graph)
  313. # FIXME
  314. if self._graph_opt_level is not None:
  315. lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
  316. else:
  317. lazy_eval_graph.options.graph_opt_level = 2
  318. lazy_eval_graph.set_priority_to_id([*lazy_eval_links, *readers])
  319. lazy_eval_graph.compile(*lazy_eval_links, *readers)
  320. lazy_eval_graph()
  321. for r, x in zip(readers, lazy_eval_tensors):
  322. x()._handle = RawTensor(r.op.get_value())._handle
  323. @contextlib.contextmanager
  324. def _setup(self):
  325. interrupted = False
  326. def do_enter():
  327. set_tracing()
  328. self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
  329. self._set_active(True)
  330. if self._untraced:
  331. self._init_trace(self._symbolic)
  332. else:
  333. # disable symbolic mode
  334. unset_symbolic()
  335. set_compiled()
  336. if self._graph is None:
  337. self._compile()
  338. self._graph.execute()
  339. def do_finalize():
  340. escaped_tensors = self._take_escaped_tensors()
  341. if self._untraced:
  342. for x in escaped_tensors:
  343. info = self._tinfo[x().mixin_handle]
  344. x().data_read = True
  345. x().mixin_handle = -1
  346. if self._inputs_to_restore:
  347. for x in self._inputs_to_restore:
  348. x.mixin_handle = -1
  349. if self._symbolic and (
  350. self._lazy_eval_tensors or self._lazy_eval_links
  351. ):
  352. # eval lazy eval tensors
  353. self._lazy_eval(
  354. self._lazy_eval_graph,
  355. tuple(self._lazy_eval_tensors),
  356. self._lazy_eval_links,
  357. )
  358. self._lazy_eval_graph = None
  359. self._lazy_eval_tensors = None
  360. self._lazy_eval_links = None
  361. self._untraced = False
  362. else:
  363. # compiled_tensor leaks
  364. if self._pc == len(self._seq):
  365. for x in escaped_tensors:
  366. try:
  367. assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
  368. except TraceMismatchError:
  369. # TraceMismatchError thrown in do_exit
  370. pass
  371. self._graph.wait()
  372. self._reset_exec_env()
  373. # reset status
  374. self._pc = 0
  375. self._tensor_remaps = None
  376. self._set_active(False)
  377. set_symbolic_shape(self._save_symbolic_shape)
  378. unset_compiled()
  379. unset_symbolic()
  380. unset_tracing()
  381. def do_exit():
  382. unset_tracing()
  383. if not self._untraced and self._pc != len(self._seq):
  384. raise TraceMismatchError("premature end")
  385. if not self._symbolic or not self._untraced:
  386. for x in self._active_tensors:
  387. x()._dev_tensor()
  388. x().mixin_handle = -1
  389. try:
  390. do_enter()
  391. yield
  392. do_exit()
  393. except:
  394. interrupted = True
  395. raise
  396. finally:
  397. do_finalize()
  398. if interrupted:
  399. self._reset()
  400. def _begin_excluded_region(self):
  401. if self._capture_as_const:
  402. raise RuntimeError(
  403. "exclude_from_trace cannot be used with capture_as_const"
  404. )
  405. if self._untraced:
  406. # conditionally reading a compiled tensor in excluded region
  407. # is permitted, so we have to assume every tensor might be read
  408. for x in self._active_tensors:
  409. info = self._tinfo[x().mixin_handle]
  410. info.exported = True
  411. x().data_read = True
  412. def _apply_graph_options(self, graph):
  413. graph.options.no_force_inplace = True
  414. graph.options.seq_opt.enable_seq_comp_node_opt = False
  415. # graph opt level
  416. # if self._graph_opt_level is not None:
  417. # graph.options.graph_opt_level = self._graph_opt_level
  418. # FIXME
  419. graph.options.graph_opt_level = 0
  420. # sublinear
  421. if self._sublinear_memory_config is not None:
  422. graph.options.enable_sublinear_memory_opt = True
  423. sublinear_config = graph.options.sublinear_mem_config
  424. sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
  425. sublinear_config.genetic_nr_iter = (
  426. self._sublinear_memory_config.genetic_nr_iter
  427. )
  428. sublinear_config.genetic_pool_size = (
  429. self._sublinear_memory_config.genetic_pool_size
  430. )
  431. sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
  432. sublinear_config.num_worker = self._sublinear_memory_config.num_worker
  433. # profile
  434. if self._profiling:
  435. self._profiler = GraphProfiler(graph)
  436. def _compile(self):
  437. graph = self._graph = G.Graph()
  438. graph.options.async_exec_level = 0b100
  439. self._apply_graph_options(graph)
  440. # graph.options.graph_opt_level = 0
  441. need_reset_nodes = self._need_reset_nodes = []
  442. # links enforce ordering of I/O nodes
  443. in_out_links = ()
  444. io_links = ()
  445. readers = []
  446. if self._capture_as_const:
  447. for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
  448. info = self._tinfo[h]
  449. opnode = info.data_setter = G.InputNode(
  450. device=info.device,
  451. dtype=info.dtype,
  452. shape=info.shape or (1,),
  453. graph=graph,
  454. use_static_shape=_input_node_use_static_shape(),
  455. )
  456. need_reset_nodes.append(opnode)
  457. info.varnode = opnode.outputs[0]
  458. in_out_links += opnode.outputs[1:]
  459. for op, ihandles, ohandles in self._seq:
  460. if isinstance(op, str) and op == "Const":
  461. assert len(ihandles) == 0
  462. (h,) = ohandles
  463. info = self._tinfo[h]
  464. if not hasattr(info, "varnode"):
  465. assert info.external
  466. assert info.bound_data
  467. info.varnode = graph.make_const(
  468. info.bound_data.numpy(),
  469. info.bound_data.dtype,
  470. info.bound_data.device,
  471. )
  472. continue
  473. require_links = type(op) in _io_op_types
  474. ivars = []
  475. for i, h in enumerate(ihandles):
  476. info = self._tinfo[h]
  477. if not hasattr(info, "varnode"):
  478. assert info.external
  479. if info.bound_data:
  480. if hasattr(info, "is_const") and info.is_const:
  481. info.varnode = graph.make_const(
  482. info.bound_data.numpy(),
  483. info.bound_data.dtype,
  484. info.bound_data.device,
  485. )
  486. else:
  487. info.varnode = graph.make_const(
  488. info.bound_data._dev_tensor()
  489. # info.bound_data.numpy()
  490. )
  491. else:
  492. opnode = info.data_setter = G.InputNode(
  493. *in_out_links,
  494. device=info.device,
  495. dtype=info.dtype,
  496. shape=info.shape or (1,),
  497. graph=graph,
  498. use_static_shape=_input_node_use_static_shape(),
  499. )
  500. need_reset_nodes.append(opnode)
  501. info.varnode, *in_out_links = opnode.outputs
  502. if require_links and i == 0 and len(io_links) > 0:
  503. opnode = G.VirtualDepNode(
  504. [info.varnode, *io_links], str(io_links[0].device)
  505. )
  506. info.varnode = opnode.outputs[0]
  507. io_links = (info.varnode,)
  508. ivars.append(info.varnode)
  509. ivars = [RawTensor(ivar) for ivar in ivars]
  510. ovars = apply(op, *ivars)
  511. ovars = [x._varnode for x in ovars]
  512. if require_links and len(ovars) > 0:
  513. io_links = (ovars[0],)
  514. assert len(ovars) == len(ohandles)
  515. for h, v in zip(ohandles, ovars):
  516. info = self._tinfo[h]
  517. info.varnode = v
  518. def add_reader(opnode):
  519. nonlocal in_out_links
  520. need_reset_nodes.append(opnode)
  521. readers.append(opnode.outputs[0])
  522. in_out_links = opnode.outputs
  523. x = self._handle2tensors[h]
  524. if x.data_read:
  525. # Shape can be obtained from data so doesn't need its own
  526. # output node. On the other hand, value is read separately
  527. # to leverage eager h2d copy
  528. info.shape_read = False
  529. opnode = info.data_reader = G.OutputNode(v, *in_out_links)
  530. add_reader(opnode)
  531. if info.value_read:
  532. opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
  533. add_reader(opnode)
  534. if info.shape_read:
  535. opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
  536. add_reader(opnode)
  537. # FIXME
  538. if self._graph_opt_level is not None:
  539. graph.options.graph_opt_level = self._graph_opt_level
  540. else:
  541. graph.options.graph_opt_level = 2
  542. graph.set_priority_to_id([*readers, *in_out_links, *io_links])
  543. graph.compile(*readers, *in_out_links, *io_links)
  544. def _reset_exec_env(self):
  545. for opnode in self._need_reset_nodes:
  546. opnode.reset()
  547. def __call__(self, *args, **kwargs):
  548. if is_tracing():
  549. return self.__wrapped__(*args, **kwargs)
  550. with self._setup():
  551. if self._capture_as_const:
  552. self._process_inputs(*args, **kwargs)
  553. outputs = self.__wrapped__(*args, **kwargs)
  554. if self._capture_as_const:
  555. self._process_outputs(outputs)
  556. return outputs
  557. def dump(
  558. self,
  559. file,
  560. *,
  561. arg_names=None,
  562. output_names=None,
  563. append=False,
  564. optimize_for_inference=True,
  565. **kwargs
  566. ):
  567. r"""
  568. Serializes trace to file system.
  569. :param file: output file, could be file object or filename.
  570. :param arg_names: names of the input tensors in the traced function.
  571. :param output_names: names of the output tensors in the traced function,
  572. use the default name if not specified.
  573. :param append: whether output is appended to ``file``.
  574. Only works when ``file`` is str.
  575. :param optimize_for_inference: enbale optmizations,
  576. will skip all optimize options if this is False. Default: True
  577. :Keyword Arguments:
  578. * enable_io16xc32 --
  579. whether to use float16 for I/O between oprs and use
  580. float32 as internal computation precision. Note the output var would be
  581. changed to float16.
  582. * enable_ioc16 --
  583. whether to use float16 for both I/O and computation
  584. precision.
  585. * enable_hwcd4 --
  586. whether to use NHWCD4 data layout. This is faster on some
  587. OpenCL backend.
  588. * enable_nchw88 --
  589. whether to use NCHW88 data layout, currently
  590. used in X86 AVX backend.
  591. * enable_nchw44 --
  592. whether to use NCHW44 data layout, currently
  593. used in arm backend.
  594. * enable_nchw44_dot --
  595. whether to use NCHW44_dot data layout, currently
  596. used in armv8.2+dotprod backend.
  597. * enable_nchw4 --
  598. whether to use NCHW4 data layout, currently
  599. used in nvidia backend(based on cudnn).
  600. * enable_nchw32 --
  601. whether to use NCHW32 data layout, currently
  602. used in nvidia backend with tensorcore(based on cudnn).
  603. * enable_chwn4 --
  604. whether to use CHWN4 data layout, currently
  605. used in nvidia backend with tensorcore.
  606. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
  607. into one opr.
  608. * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
  609. input for inference on nvidia backend(this optimization pass will
  610. result in mismatch of the precision of output of training and
  611. inference)
  612. """
  613. if not self._capture_as_const:
  614. raise ValueError(
  615. "you must specify capture_as_const=True at __init__ to use dump"
  616. )
  617. if self._untraced:
  618. raise RuntimeError("should run at least once before calling dump")
  619. if self._output_names and output_names:
  620. raise TypeError(
  621. "cannot specify output_names when output is already in dict format"
  622. )
  623. if output_names and not isinstance(output_names, collections.abc.Sequence):
  624. output_names = (output_names,)
  625. if output_names and len(output_names) != len(self._output_bindings):
  626. raise ValueError(
  627. "wrong number of output_names, should be {} values".format(
  628. len(self._output_bindings)
  629. )
  630. )
  631. if arg_names is None:
  632. arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
  633. if arg_names and not isinstance(arg_names, collections.abc.Sequence):
  634. arg_names = (arg_names,)
  635. if arg_names and len(arg_names) != len(self._arg_bindings):
  636. raise ValueError(
  637. "wrong number of arg_names, should be {} values".format(
  638. len(self._arg_bindings)
  639. )
  640. )
  641. output_names = output_names or self._output_names
  642. dumped_device = as_device("xpux")
  643. h2v = {}
  644. graph = G.Graph()
  645. # only graph_opt_level takes effect in dump
  646. self._apply_graph_options(graph)
  647. for i, h in enumerate(self._arg_bindings):
  648. info = self._tinfo[h]
  649. h2v[h] = graph.make_h2d(
  650. dtype=info.dtype,
  651. device=dumped_device,
  652. shape=info.shape or (1,),
  653. name=arg_names[i] if arg_names else None,
  654. )
  655. for k, h in self._kwarg_bindings.items():
  656. info = self._tinfo[h]
  657. h2v[h] = graph.make_h2d(
  658. dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
  659. )
  660. set_tracing()
  661. for op, ihandles, ohandles in self._seq:
  662. if isinstance(op, str) and op == "Const":
  663. assert len(ihandles) == 0
  664. (h,) = ohandles
  665. info = self._tinfo[h]
  666. if h not in h2v:
  667. assert info.external
  668. assert info.bound_data
  669. h2v[h] = graph.make_const(
  670. info.bound_data.numpy(), dtype=info.dtype, device=info.device,
  671. )
  672. continue
  673. ivars = []
  674. for h in ihandles:
  675. info = self._tinfo[h]
  676. if h not in h2v:
  677. assert info.external
  678. assert info.bound_data
  679. h2v[h] = graph.make_const(
  680. info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
  681. )
  682. ivars.append(h2v[h])
  683. ivars = [RawTensor(ivar) for ivar in ivars]
  684. ovars = apply(op, *ivars)
  685. ovars = [x._varnode for x in ovars]
  686. assert len(ovars) == len(ohandles)
  687. h2v.update(zip(ohandles, ovars))
  688. dest_vars = []
  689. for i, h in enumerate(self._output_bindings):
  690. v = h2v[h]
  691. if output_names:
  692. v.name = output_names[i]
  693. dest_vars.append(v)
  694. dest_vars = [G.VarNode(var) for var in dest_vars]
  695. if optimize_for_inference:
  696. dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
  697. if isinstance(file, str):
  698. permission = "wb" if append == False else "ab"
  699. file = open(file, permission)
  700. dump_content, dump_info = G.dump_graph(dest_vars)
  701. file.write(dump_content)
  702. return dump_info
  703. def _process_inputs(self, *args, **kwargs):
  704. if self._untraced:
  705. self._inputs_to_restore = []
  706. def record_input(x):
  707. if x is None:
  708. return
  709. h, info = self._new_handle()
  710. info.external = False
  711. info.device = x.device
  712. info.dtype = x.dtype
  713. info.shape = x.numpy().shape
  714. x.mixin_handle = h
  715. self._handle2tensors[h] = x
  716. self._inputs_to_restore.append(x)
  717. return h
  718. self._arg_bindings = []
  719. for i, x in enumerate(args):
  720. if not isinstance(x, RawTensor):
  721. raise TypeError(
  722. "positional arguments should all be tensor "
  723. "but args[%d] cannot be recognized as one" % i
  724. )
  725. self._arg_bindings.append(record_input(x))
  726. self._kwarg_bindings = {}
  727. for k, x in kwargs.items():
  728. if isinstance(x, RawTensor):
  729. self._kwarg_bindings[k] = record_input(x)
  730. else:
  731. if len(args) != len(self._arg_bindings):
  732. raise TraceMismatchError("positional argument length mismatch")
  733. self._tensor_remaps = {}
  734. for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
  735. if not isinstance(x, RawTensor):
  736. raise TypeError(
  737. "positional arguments should all be tensor "
  738. "but args[%d] cannot be recognized as one" % i
  739. )
  740. info = self._tinfo[h]
  741. if x.dtype != info.dtype:
  742. raise TypeError("args[%d].dtype different from last time" % i)
  743. if x.device != info.device:
  744. raise TypeError("args[%d].device different from last time" % i)
  745. info.data_setter.set_value(x._dev_tensor())
  746. self._tensor_remaps[x] = CompiledTensorProxy(h)
  747. kwargs_tensors = {}
  748. for k, x in kwargs.items():
  749. if isinstance(x, RawTensor):
  750. kwargs_tensors[k] = x
  751. if set(kwargs_tensors) != set(self._kwarg_bindings):
  752. too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
  753. too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
  754. if too_many:
  755. raise TraceMismatchError(
  756. "keyword arguments found to be tensor this time "
  757. "but were non-tensor previously: %s" % " ".join(too_many)
  758. )
  759. if too_few:
  760. raise TraceMismatchError(
  761. "keyword arguments found to be non-tensor this time "
  762. "but were tensor previously: %s" % " ".join(too_few)
  763. )
  764. for k, h in self._kwarg_bindings.items():
  765. x = kwargs_tensors[k]
  766. info = self._tinfo[h]
  767. if x.dtype != info.dtype:
  768. raise TypeError("kwargs[%s].dtype different from last time" % k)
  769. if x.device != info.device:
  770. raise TypeError("kwargs[%s].device different from last time" % k)
  771. info.data_setter.set_value(x._dev_tensor())
  772. self._tensor_remaps[x] = CompiledTensorProxy(h)
  773. def _process_outputs(self, outputs):
  774. output_names = None
  775. if isinstance(outputs, collections.abc.Mapping):
  776. output_names, outputs = zip(*sorted(outputs.items()))
  777. elif not isinstance(outputs, collections.abc.Sequence):
  778. outputs = (outputs,)
  779. if not self._untraced:
  780. if output_names != self._output_names:
  781. too_many = set(output_names) - set(self._output_names)
  782. too_few = set(self._output_names) - set(output_names)
  783. if too_many:
  784. raise TraceMismatchError(
  785. "output has more keys than last time: %s" % " ".join(too_many)
  786. )
  787. if too_few:
  788. raise TraceMismatchError(
  789. "output has less keys than last time: %s" % " ".join(too_few)
  790. )
  791. if len(outputs) != len(self._output_bindings):
  792. raise TraceMismatchError("output size differs from last time")
  793. else:
  794. self._output_names = output_names
  795. self._output_bindings = []
  796. for i, x in enumerate(outputs):
  797. if not isinstance(x, RawTensor):
  798. raise TypeError("every item of return value should be tensor")
  799. if self._untraced:
  800. h = x.mixin_handle
  801. if h < 0:
  802. raise RuntimeError("output is not computed from inputs")
  803. self._output_bindings.append(h)
  804. else:
  805. h = x.mixin_handle
  806. if h not in self._handle2compiledtensors:
  807. raise RuntimeError("output is not computed from inputs")
  808. if h != self._output_bindings[i]:
  809. raise TraceMismatchError(
  810. "retval[%s] is a different tensor than last time"
  811. % (output_names and output_names[i] or i)
  812. )
  813. def get_profile(self):
  814. """
  815. Get profiling result for compiled trace.
  816. :return: a json compatible object.
  817. """
  818. if not self._profiler:
  819. raise RuntimeError("trace is not set with profiling=True")
  820. return json.loads(self._profiler.get())
  821. def trace(self, *args, **kwargs):
  822. raise NotImplementedError(
  823. "trace is deemed unbeneficial with the new "
  824. "tracing mechanism. You should alwasy use __call__."
  825. )
  826. class CompiledTensorProxy:
  827. """
  828. Duck-typed RawTensor
  829. """
  830. def __init__(self, handle):
  831. self.__handle = handle
  832. self._isscalar = False
  833. self.__info = active_trace._tinfo[handle]
  834. self.__shape = None
  835. self.__data = None
  836. self.__value = None
  837. self.__tensor = active_trace._handle2tensors[handle]
  838. self.__tensor.mixin_handle = handle
  839. @property
  840. def dtype(self):
  841. return self.__info.varnode.dtype
  842. @property
  843. def device(self):
  844. return self.__info.varnode.device
  845. @property
  846. def shape(self):
  847. if self._isscalar:
  848. return ()
  849. if self.__shape is None:
  850. if self.__tensor.shape_read:
  851. self.__shape = self.__info.shape_reader.get_value().shape
  852. elif self.__tensor.data_read:
  853. self.__shape = self.__tensor._dev_tensor().shape
  854. else:
  855. raise TraceMismatchError("shape of this tensor is not read in trace")
  856. return self.__shape
  857. def numpy(self):
  858. if self.__value is None:
  859. if self.__tensor.value_read:
  860. self.__value = self.__info.value_reader.get_value()
  861. elif self.__tensor.data_read:
  862. self.__value = self._dev_tensor().numpy()
  863. else:
  864. raise TraceMismatchError("value of this tensor is not read in trace")
  865. if self._isscalar:
  866. self.__value = self.__value.squeeze()
  867. return self.__value
  868. def _dev_tensor(self):
  869. if self.__data is None:
  870. if not self.__tensor.data_read:
  871. raise TraceMismatchError("raw data of this tensor is not read in trace")
  872. self.__data = self.__info.data_reader.get_value()
  873. self.__tensor._reset(RawTensor(self.__data))
  874. self.__tensor.mixin_handle = self.__handle
  875. return self.__data
  876. def _drop(self):
  877. return
  878. def _swap_in(self):
  879. return
  880. def _swap_out(self):
  881. return
  882. def __del__(self):
  883. if self.__tensor.shape_read and self.__shape is not None:
  884. self.__info.shape_reader.drop_value()
  885. # if self.__tensor.value_read and self.__value is not None:
  886. # self.__info.value_reader.drop_value()
  887. if self.__tensor.data_read and self.__data is not None:
  888. self.__info.data_reader.drop_value()
  889. def assign_raw_tensor(lhs, rhs):
  890. lhs.__init__(rhs)
  891. # this hook turns RawTensor into LazyEvalTensor(varnode)
  892. def apply_symbolic_mode(op: OpDef, *args: RawTensor):
  893. graph = active_trace._lazy_eval_graph
  894. ivars = []
  895. for x in args:
  896. var = getattr(x, "_varnode", None)
  897. if var:
  898. ivars.append(var)
  899. else:
  900. data_setter = G.InputNode(
  901. device=x.device,
  902. dtype=x.dtype,
  903. shape=x.numpy().shape or (1,),
  904. graph=graph,
  905. use_static_shape=True,
  906. )
  907. var = data_setter.outputs[0]
  908. ivars.append(var)
  909. data_setter.set_value(x._dev_tensor())
  910. require_links = type(op) in _io_op_types
  911. if require_links and active_trace._lazy_eval_links:
  912. assert len(ivars) > 0, "op should has at least one input"
  913. opnode = G.VirtualDepNode(
  914. [ivars[0], *active_trace._lazy_eval_links],
  915. str(active_trace._lazy_eval_links[0].device),
  916. )
  917. ivars[0] = opnode.outputs[0]
  918. active_trace._lazy_eval_links = (ivars[0],)
  919. ivars = [
  920. RawTensor(ivar._node) if hasattr(ivar, "_node") else RawTensor(ivar)
  921. for ivar in ivars
  922. ]
  923. unset_symbolic()
  924. outputs = apply(op, *ivars)
  925. set_symbolic()
  926. if require_links:
  927. active_trace._lazy_eval_links = (outputs[0]._varnode,)
  928. active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs])
  929. return outputs
  930. def apply_const_symbolic_mode(value, dtype, device):
  931. graph = active_trace._lazy_eval_graph
  932. # don't need to unset tracing
  933. # because varnode construction will ignore tracing flag
  934. ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
  935. active_trace._lazy_eval_tensors.add(TensorWeakRef(ret))
  936. return (ret,)
  937. def apply_compiled_mode(op: OpDef, *args: RawTensor):
  938. if skip_tracing:
  939. args = [
  940. RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
  941. for x in args
  942. ]
  943. unset_tracing()
  944. ret = apply(op, *args)
  945. set_tracing()
  946. return ret
  947. return active_trace._apply_op(op, args)
  948. def apply_const_compiled_mode(value, dtype, device, is_const):
  949. if skip_tracing:
  950. args = [
  951. RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
  952. for x in args
  953. ]
  954. unset_tracing()
  955. ret = RawTensor(value, dtype, device, False)
  956. set_tracing()
  957. return ret
  958. return active_trace._apply_const(value, dtype, device)
  959. # this hook injects TraceMixin
  960. def apply_with_tracing(op: OpDef, *args: RawTensor):
  961. if active_trace._symbolic:
  962. outputs = apply_symbolic_mode(op, *args)
  963. else:
  964. unset_tracing()
  965. outputs = apply(op, *args)
  966. set_tracing()
  967. active_trace._record_op(op, args, outputs)
  968. return list(outputs)
  969. def apply_const_with_tracing(value, dtype, device, is_const):
  970. if active_trace._symbolic:
  971. outputs = apply_const_symbolic_mode(value, dtype, device)
  972. else:
  973. unset_tracing()
  974. outputs = (RawTensor(value, dtype, device, False),)
  975. set_tracing()
  976. active_trace._record_const(outputs)
  977. return list(outputs)

MegEngine 安装包中已集成使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器配有 GPU 硬件设备并已安装好驱动。如果你想在云端 GPU 算力平台上体验深度学习开发，欢迎访问 MegStudio 平台。