You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tracing.py 42 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import contextlib
  11. import functools
  12. import itertools
  13. import json
  14. import os
  15. import typing
  16. import warnings
  17. import weakref
  18. import numpy as np
  19. from ..core._imperative_rt import GraphProfiler, common
  20. from ..core._imperative_rt.core2 import Tensor as RawTensor
  21. from ..core._imperative_rt.core2 import (
  22. TensorWeakRef,
  23. apply,
  24. set_compiled,
  25. set_tracing,
  26. skip_tracing,
  27. unset_compiled,
  28. unset_tracing,
  29. )
  30. from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
  31. from ..core._trace_option import set_symbolic_shape
  32. from ..core._wrap import device as as_device
  33. from ..core.ops.builtin import BackwardGraph, OpDef
  34. from ..core.ops.special import Const
  35. from ..core.tensor import megbrain_graph as G
  36. from ..core.tensor.utils import setscalar
  37. from .sublinear_memory_config import SublinearMemoryConfig
  38. def _input_node_use_static_shape():
  39. return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
  40. class TraceMismatchError(RuntimeError):
  41. pass
  42. active_trace = None
  43. def is_tracing():
  44. if active_trace is None:
  45. return False
  46. else:
  47. return not skip_tracing
  48. @contextlib.contextmanager
  49. def exclude_from_trace():
  50. global skip_tracing
  51. if skip_tracing:
  52. yield
  53. return
  54. try:
  55. skip_tracing = True
  56. unset_tracing()
  57. if active_trace is not None:
  58. active_trace._begin_excluded_region()
  59. yield
  60. finally:
  61. skip_tracing = False
  62. set_tracing()
  63. class TensorInfo:
  64. __slots__ = (
  65. # collected attributes
  66. "external",
  67. "data_read",
  68. "shape_read",
  69. "value_read",
  70. "exported",
  71. "device",
  72. "dtype",
  73. "shape",
  74. "is_const",
  75. "bound_data",
  76. # resources for execution
  77. "varnode",
  78. "data_setter",
  79. "shape_reader",
  80. "value_reader",
  81. "data_reader",
  82. )
  83. def __init__(self):
  84. self.exported = None
  85. self.data_read = None
  86. self.shape_read = None
  87. self.value_read = None
  88. self.bound_data = None
  89. self.data_setter = None
  90. self.shape_reader = None
  91. self.value_reader = None
  92. self.data_reader = None
  93. _io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
  94. class trace:
  95. """
  96. Wraps a callable and provide:
  97. * tracing via :meth:`.trace` and :meth:`.dump`
  98. * accelerated evalutaion via :meth:`.__call__`
  99. :param function: the function will be traced.
  100. :param symbolic: whether to apply symbolic execution for tracing. Default: False
  101. :param capture_as_const: capture global vars or closures as const value. Default: False
  102. :param sublinear_memory_config: configuration for sublinear memory optimization.
  103. If not None, it enables sublinear memory optimization with given setting.
  104. :param profiling: whether to profile compiled trace. Default: False
  105. :param opt_level: optimization level for compiling trace.
  106. :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
  107. """
  108. def __new__(cls, *args, **kwargs):
  109. if not args:
  110. return functools.partial(cls, **kwargs)
  111. return super().__new__(cls)
  112. def __init__(
  113. self,
  114. function,
  115. symbolic=False,
  116. capture_as_const=False,
  117. sublinear_memory_config: SublinearMemoryConfig = None,
  118. profiling: bool = False,
  119. opt_level: int = None,
  120. symbolic_shape: bool = True,
  121. ):
  122. self.__wrapped__ = function
  123. self._symbolic = symbolic
  124. self._capture_as_const = capture_as_const
  125. self._sublinear_memory_config = sublinear_memory_config
  126. self._profiling = profiling
  127. self._profiler = None
  128. self._graph_opt_level = opt_level
  129. self._symbolic_shape = symbolic_shape
  130. self._output_handles = set()
  131. self._reset()
  132. def _reset(self):
  133. self._untraced = True
  134. self._tinfo = [] # handle -> TensorInfo
  135. self._seq = []
  136. self._pc = 0
  137. self._graph = None
  138. self._need_reset_nodes = None
  139. self._lazy_eval_graph = None
  140. self._lazy_eval_tensors = {}
  141. self._lazy_eval_links = None
  142. self._active_tensors = {}
  143. self._tensor_remaps = None
  144. self._inputs_to_restore = None
  145. self._arg_bindings = None
  146. self._kwarg_bindings = None
  147. self._output_bindings = None
  148. self._output_names = None
  149. def _new_handle(self):
  150. handle = len(self._tinfo)
  151. info = TensorInfo()
  152. self._tinfo.append(info)
  153. return handle, info
  154. def _apply_op(self, op, args):
  155. assert not self._untraced
  156. # check against trace
  157. if self._pc >= len(self._seq):
  158. raise TraceMismatchError("trace should end here, but more op observed")
  159. record = self._seq[self._pc]
  160. op_, ihandles, ohandles = record
  161. if (isinstance(op_, str) and op_ == "Const") or (op != op_):
  162. raise TraceMismatchError("op different from last time")
  163. if len(ihandles) != len(args):
  164. raise TraceMismatchError("op input size different from last time")
  165. # check all inputs of crrent op
  166. for h, x in zip(ihandles, args):
  167. info = self._tinfo[h]
  168. if info.external:
  169. if (
  170. x._compiled_info is not None
  171. and not self._tinfo[x._mixin_handle].exported
  172. ):
  173. raise TraceMismatchError(
  174. "failed to capture: input was an external tensor "
  175. "last time, got an internal tensor this time"
  176. )
  177. if info.bound_data:
  178. if x._compiled_info is not None:
  179. raise TraceMismatchError(
  180. "const capture violated: was an external tensor "
  181. "last time, got an internal tensor this time"
  182. )
  183. if x._handle != info.bound_data._handle:
  184. if not np.array_equal(x.numpy(), info.bound_data.numpy()):
  185. raise TraceMismatchError(
  186. "const capture violated: got "
  187. "a different tensor this time"
  188. )
  189. else:
  190. if info.dtype != x.dtype:
  191. raise TraceMismatchError(
  192. "failed to capture: different dtype from last time"
  193. )
  194. if info.device != x.device:
  195. raise TraceMismatchError(
  196. "failed to capture: different device from last time"
  197. )
  198. info.data_setter.set_value(x._dev_tensor())
  199. else:
  200. if x._mixin_handle == -1:
  201. if x._handle not in self._tensor_remaps:
  202. raise TraceMismatchError(
  203. "unexpected capture: trying to use an external tensor as "
  204. "input, but that input was an internal tensor last time"
  205. )
  206. else:
  207. x._mixin_handle = self._tensor_remaps[
  208. x._handle
  209. ]._CompiledTensorProxy__handle
  210. if x._mixin_handle != h:
  211. raise TraceMismatchError(
  212. "mis-wiring: input edge to an data flow "
  213. "graph node is different from last time"
  214. )
  215. self._pc += 1
  216. outputs = []
  217. for h in ohandles:
  218. info = self._tinfo[h]
  219. # generate output tensor and create compied info
  220. y = RawTensor(info.varnode)
  221. y._compiled_info = CompiledTensorProxy(h)
  222. y._mixin_handle = h
  223. outputs += [y]
  224. self._active_tensors[h] = TensorWeakRef(y)
  225. self._output_handles.update(ohandles)
  226. return outputs
  227. def _apply_const(self, value, dtype, device):
  228. assert not self._untraced
  229. # check against trace
  230. if self._pc >= len(self._seq):
  231. raise TraceMismatchError("trace should end here, but more op observed")
  232. record = self._seq[self._pc]
  233. op_, ihandles, ohandles = record
  234. # Const op is represented by a str
  235. assert isinstance(op_, str) and op_ == "Const"
  236. eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
  237. if not eq:
  238. raise TraceMismatchError(
  239. "const tensor violated: got a different tensor this time"
  240. )
  241. self._pc += 1
  242. (h,) = ohandles
  243. outputs = [self._tinfo[h].bound_data]
  244. return outputs
  245. # run in first step, record information for trace
  246. def _record_op(self, op, inputs, outputs):
  247. if skip_tracing:
  248. for x in inputs:
  249. h = getattr(x, "_mixin_handle", -1)
  250. if h >= 0:
  251. self._tinfo[h].data = True
  252. return
  253. ihandles = []
  254. for x in inputs:
  255. h = getattr(x, "_mixin_handle", -1)
  256. if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
  257. h, info = self._new_handle()
  258. info.external = True
  259. info.device = x.device
  260. info.dtype = x.dtype
  261. info.shape = x.shape
  262. if self._capture_as_const:
  263. info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False)
  264. ihandles.append(h)
  265. ohandles = []
  266. for x in outputs:
  267. h, info = self._new_handle()
  268. ohandles.append(h)
  269. info.external = False
  270. x._mixin_handle = h
  271. x._recording = True
  272. x._trace_mixin_info = info
  273. self._active_tensors[h] = TensorWeakRef(x)
  274. if self._symbolic:
  275. self._lazy_eval_tensors[h] = TensorWeakRef(x)
  276. self._seq.append((op, tuple(ihandles), tuple(ohandles)))
  277. def _record_const(self, outputs):
  278. if skip_tracing:
  279. (x,) = outputs
  280. h = getattr(x, "_mixin_handle", -1)
  281. if h >= 0:
  282. self._tinfo[h].data_read = True
  283. return
  284. (x,) = outputs
  285. h, info = self._new_handle()
  286. ohandles = [h]
  287. info.external = True
  288. info.device = x.device
  289. info.dtype = x.dtype
  290. info.shape = x.shape
  291. info.bound_data = x
  292. info.is_const = True
  293. x._mixin_handle = h
  294. x._recording = True
  295. x._trace_mixin_info = info
  296. if self._symbolic:
  297. self._lazy_eval_tensors[h] = TensorWeakRef(x)
  298. self._seq.append(("Const", tuple(), tuple(ohandles)))
  299. def _set_active(self, active: bool):
  300. global active_trace
  301. if active:
  302. if active_trace:
  303. raise NotImplementedError("sorry, not implemented: nested trace")
  304. active_trace = self
  305. else:
  306. assert active_trace is self
  307. active_trace = None
  308. def _init_trace(self, symbolic: bool):
  309. if symbolic:
  310. self._lazy_eval_graph = G.Graph()
  311. self._apply_graph_options(self._lazy_eval_graph)
  312. self._lazy_eval_links = ()
  313. def _take_escaped_tensors(self):
  314. escaped_tensors = tuple(
  315. filter(lambda x: x() is not None, self._active_tensors.values())
  316. )
  317. self._active_tensors.clear()
  318. return escaped_tensors
  319. def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
  320. lazy_eval_tensors = list(
  321. filter(lambda x: x() is not None, lazy_eval_tensors.values())
  322. )
  323. readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
  324. self._apply_graph_options(lazy_eval_graph)
  325. # FIXME
  326. if self._graph_opt_level is not None:
  327. lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
  328. else:
  329. lazy_eval_graph.options.graph_opt_level = 2
  330. lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
  331. lazy_eval_graph.compile(*lazy_eval_links, *readers)
  332. lazy_eval_graph()
  333. for r, x in zip(readers, lazy_eval_tensors):
  334. # get values from lazy_eval_graph and assign to lazy_eval tensor
  335. x()._handle = RawTensor(r.op.get_value())._handle
  336. x()._reset_varnode()
    @contextlib.contextmanager
    def _setup(self):
        """Context manager wrapping one call of the traced function.

        Entry (``do_enter``):
            Enables tracing, applies ``symbolic_shape`` and activates this
            trace. On the first call it prepares recording (``_init_trace``).
            On later calls it switches to compiled mode, compiles the graph
            if that has not happened yet, and starts executing it.

        Normal exit (``do_exit``):
            Checks that a compiled run consumed the whole recorded sequence,
            then detaches the remaining active tensors from the trace.

        Always (``do_finalize``):
            Restores global state. If anything raised, the recorded trace is
            also discarded via ``_reset``.
        """
        interrupted = False

        def do_enter():
            set_tracing()
            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
            self._set_active(True)
            if self._untraced:
                self._init_trace(self._symbolic)
            else:
                set_compiled()
                # compile lazily on the first replayed call
                if self._graph is None:
                    self._compile()
                self._graph.execute()

        def do_finalize():
            escaped_tensors = self._take_escaped_tensors()
            if self._untraced:
                # First (tracing) run: tensors that outlive the call will have
                # their data read, so flag that and detach them from the trace.
                # NOTE(review): `if x():` tests tensor truthiness, although
                # escaped_tensors were already filtered with `is not None`.
                # Confirm the tensor type's __bool__ is not value-based;
                # `is not None` may be what is intended.
                for x in escaped_tensors:
                    if x():
                        info = self._tinfo[x()._mixin_handle]
                        info.data_read = True
                        x()._mixin_handle = -1
                        x()._recording = False
                if self._inputs_to_restore:
                    for x in self._inputs_to_restore:
                        x._mixin_handle = -1
                        x._recording = False
                if self._symbolic and (
                    self._lazy_eval_tensors or self._lazy_eval_links
                ):
                    # eval lazy eval tensors
                    self._lazy_eval(
                        self._lazy_eval_graph,
                        self._lazy_eval_tensors,
                        self._lazy_eval_links,
                    )
                    self._lazy_eval_graph = None
                    self._lazy_eval_tensors = None
                    self._lazy_eval_links = None
                self._untraced = False
            else:
                # compiled_tensor leaks: if the whole sequence was replayed,
                # give each escaped tensor a standalone copy of its device
                # value (assign_raw_tensor is defined outside this view).
                if self._pc == len(self._seq):
                    for x in escaped_tensors:
                        try:
                            assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
                        except RuntimeError:
                            # TraceMismatchError thrown in do_exit
                            pass
                self._graph.wait()
                self._reset_exec_env()

            # reset status
            self._pc = 0
            self._tensor_remaps = None
            # NOTE(review): this also runs when do_enter failed before or at
            # _set_active(True) (e.g. the nested-trace NotImplementedError).
            # The assertion inside _set_active(False), or a missing
            # _save_symbolic_shape, would then replace the original exception.
            # Confirm whether that is acceptable.
            self._set_active(False)
            set_symbolic_shape(self._save_symbolic_shape)
            unset_compiled()
            unset_tracing()

        def do_exit():
            unset_tracing()
            if not self._untraced and self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            if not self._symbolic or not self._untraced:
                # reset output tensors: fetch device data and detach from trace
                for x in self._active_tensors.values():
                    if x() is not None:
                        x()._dev_tensor()
                        x()._reset_varnode()
                        x()._mixin_handle = -1
                        x()._recording = False
                        x()._trace_mixin_info = None

        try:
            do_enter()
            yield
            do_exit()
        except:
            # Bare except: any exception (BaseException subclasses included)
            # marks the run as interrupted before being re-raised.
            interrupted = True
            raise
        finally:
            do_finalize()
            if interrupted:
                # a failed run leaves a partial trace; start over next call
                self._reset()
  419. def _begin_excluded_region(self):
  420. if self._capture_as_const:
  421. raise RuntimeError(
  422. "exclude_from_trace cannot be used with capture_as_const"
  423. )
  424. if self._untraced:
  425. # conditionally reading a compiled tensor in excluded region
  426. # is permitted, so we have to assume every tensor might be read
  427. for x in self._active_tensors.values():
  428. if x():
  429. info = self._tinfo[x()._mixin_handle]
  430. info.exported = True
  431. info.data_read = True
  432. else:
  433. for x in self._active_tensors.values():
  434. if x():
  435. x()._dev_tensor()
  436. def _apply_graph_options(self, graph):
  437. graph.options.no_force_inplace = True
  438. graph.options.seq_opt.enable_seq_comp_node_opt = False
  439. # graph opt level
  440. # if self._graph_opt_level is not None:
  441. # graph.options.graph_opt_level = self._graph_opt_level
  442. # FIXME
  443. graph.options.graph_opt_level = 0
  444. # sublinear
  445. if self._sublinear_memory_config is not None:
  446. graph.options.enable_sublinear_memory_opt = True
  447. sublinear_config = graph.options.sublinear_mem_config
  448. sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
  449. sublinear_config.genetic_nr_iter = (
  450. self._sublinear_memory_config.genetic_nr_iter
  451. )
  452. sublinear_config.genetic_pool_size = (
  453. self._sublinear_memory_config.genetic_pool_size
  454. )
  455. sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
  456. sublinear_config.num_worker = self._sublinear_memory_config.num_worker
  457. # profile
  458. if self._profiling:
  459. self._profiler = GraphProfiler(graph)
  460. if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
  461. graph.options.var_sanity_check_first_run = False
    def _compile(self):
        """Build and compile the execution graph from the recorded sequence.

        The graph is built as follows:

        - External inputs become InputNodes; their ``data_setter`` is fed on
          each run.
        - Captured data becomes graph constants.
        - Every recorded op is re-applied on varnodes.
        - Output nodes are added for tensors whose data, value or shape was
          read while tracing.

        ``in_out_links`` chains input/output nodes, and ``io_links`` chains
        communication ops (``_io_op_types``), so that both keep their
        recorded order.
        """
        graph = self._graph = G.Graph()
        graph.options.async_exec_level = 0b100
        self._apply_graph_options(graph)
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        in_out_links = ()
        io_links = ()
        readers = []

        if self._capture_as_const:
            # function arguments bound by _process_inputs get dedicated inputs
            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                info = self._tinfo[h]
                opnode = info.data_setter = G.InputNode(
                    device=info.device,
                    dtype=info.dtype,
                    shape=info.shape or (1,),
                    graph=graph,
                    use_static_shape=_input_node_use_static_shape(),
                )
                need_reset_nodes.append(opnode)
                info.varnode = opnode.outputs[0]
                in_out_links += opnode.outputs[1:]

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                # ``varnode`` is left unset by TensorInfo.__init__, so hasattr
                # tells whether this handle already has a graph node
                if not hasattr(info, "varnode"):
                    assert info.external
                    assert info.bound_data
                    info.varnode = graph.make_const(
                        info.bound_data.numpy(),
                        info.bound_data.dtype,
                        info.bound_data.device,
                    )
                continue

            require_links = type(op) in _io_op_types

            ivars = []
            for i, h in enumerate(ihandles):
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    # NOTE(review): tensor truthiness is used as an "is bound"
                    # test here; `is not None` may be the intent - confirm.
                    if info.bound_data:
                        if hasattr(info, "is_const") and info.is_const:
                            info.varnode = graph.make_const(
                                info.bound_data.numpy(),
                                info.bound_data.dtype,
                                info.bound_data.device,
                            )
                        else:
                            info.varnode = graph.make_const(
                                info.bound_data._dev_tensor()
                                # info.bound_data.numpy()
                            )
                    else:
                        # external input fed at run time via data_setter
                        opnode = info.data_setter = G.InputNode(
                            *in_out_links,
                            device=info.device,
                            dtype=info.dtype,
                            shape=info.shape or (1,),
                            graph=graph,
                            use_static_shape=_input_node_use_static_shape(),
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *in_out_links = opnode.outputs
                if require_links and i == 0 and len(io_links) > 0:
                    # make the first input of an I/O op depend on the previous
                    # I/O op so communication ops keep their recorded order
                    opnode = G.VirtualDepNode(
                        [info.varnode, *io_links], str(io_links[0].device)
                    )
                    info.varnode = opnode.outputs[0]
                    io_links = (info.varnode,)

                ivars.append(info.varnode)

            if isinstance(op, BackwardGraph):
                ovars = G.apply_backward_varnode(op, *ivars)
            else:
                ovars = G.apply_normal_varnode(op, *ivars)

            if require_links and len(ovars) > 0:
                io_links = (ovars[0],)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    # register an output node and chain later I/O nodes after it
                    nonlocal in_out_links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    in_out_links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                    add_reader(opnode)

        # FIXME: override the opt level pinned to 0 by _apply_graph_options
        if self._graph_opt_level is not None:
            graph.options.graph_opt_level = self._graph_opt_level
        else:
            graph.options.graph_opt_level = 2
        graph._set_priority_to_id([*readers, *in_out_links, *io_links])
        graph.compile(*readers, *in_out_links, *io_links)
  570. def _reset_exec_env(self):
  571. for opnode in self._need_reset_nodes:
  572. opnode.reset()
  573. def __call__(self, *args, **kwargs):
  574. if is_tracing():
  575. return self.__wrapped__(*args, **kwargs)
  576. with self._setup():
  577. if self._capture_as_const:
  578. self._process_inputs(*args, **kwargs)
  579. outputs = self.__wrapped__(*args, **kwargs)
  580. transform = False
  581. # outputs can be None
  582. if outputs is not None:
  583. if not isinstance(outputs, collections.abc.Sequence):
  584. transform = True
  585. outputs = (outputs,)
  586. for o in outputs:
  587. # if outputs are copied, then use the newest info in trace data structure
  588. if o._copied:
  589. self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
  590. if self._untraced and self._symbolic:
  591. self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
  592. if self._capture_as_const:
  593. self._process_outputs(outputs)
  594. if transform:
  595. outputs = outputs[0]
  596. return outputs
  597. def dump(
  598. self,
  599. file,
  600. *,
  601. arg_names=None,
  602. output_names=None,
  603. append=False,
  604. optimize_for_inference=True,
  605. **kwargs
  606. ):
  607. r"""
  608. Serializes trace to file system.
  609. :param file: output file, could be file object or filename.
  610. :param arg_names: names of the input tensors in the traced function.
  611. :param output_names: names of the output tensors in the traced function,
  612. use the default name if not specified.
  613. :param append: whether output is appended to ``file``.
  614. Only works when ``file`` is str.
  615. :param optimize_for_inference: enbale optmizations,
  616. will skip all optimize options if this is False. Default: True
  617. :Keyword Arguments:
  618. * enable_io16xc32 --
  619. whether to use float16 for I/O between oprs and use
  620. float32 as internal computation precision. Note the output var would be
  621. changed to float16.
  622. * enable_ioc16 --
  623. whether to use float16 for both I/O and computation
  624. precision.
  625. * enable_hwcd4 --
  626. whether to use NHWCD4 data layout. This is faster on some
  627. OpenCL backend.
  628. * enable_nchw88 --
  629. whether to use NCHW88 data layout, currently
  630. used in X86 AVX backend.
  631. * enable_nchw44 --
  632. whether to use NCHW44 data layout, currently
  633. used in arm backend.
  634. * enable_nchw44_dot --
  635. whether to use NCHW44_dot data layout, currently
  636. used in armv8.2+dotprod backend.
  637. * enable_nchw4 --
  638. whether to use NCHW4 data layout, currently
  639. used in nvidia backend(based on cudnn).
  640. * enable_nchw32 --
  641. whether to use NCHW32 data layout, currently
  642. used in nvidia backend with tensorcore(based on cudnn).
  643. * enable_chwn4 --
  644. whether to use CHWN4 data layout, currently
  645. used in nvidia backend with tensorcore.
  646. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
  647. into one opr.
  648. * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
  649. input for inference on nvidia backend(this optimization pass will
  650. result in mismatch of the precision of output of training and
  651. inference)
  652. """
  653. if not self._capture_as_const:
  654. raise ValueError(
  655. "you must specify capture_as_const=True at __init__ to use dump"
  656. )
  657. if self._untraced:
  658. raise RuntimeError("should run at least once before calling dump")
  659. if self._output_names and output_names:
  660. raise TypeError(
  661. "cannot specify output_names when output is already in dict format"
  662. )
  663. if output_names and not isinstance(output_names, collections.abc.Sequence):
  664. output_names = (output_names,)
  665. if output_names and len(output_names) != len(self._output_bindings):
  666. raise ValueError(
  667. "wrong number of output_names, should be {} values".format(
  668. len(self._output_bindings)
  669. )
  670. )
  671. if arg_names is None:
  672. arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
  673. if arg_names and not isinstance(arg_names, collections.abc.Sequence):
  674. arg_names = (arg_names,)
  675. if arg_names and len(arg_names) != len(self._arg_bindings):
  676. raise ValueError(
  677. "wrong number of arg_names, should be {} values".format(
  678. len(self._arg_bindings)
  679. )
  680. )
  681. output_names = output_names or self._output_names
  682. dumped_device = as_device("xpux")
  683. h2v = {}
  684. graph = G.Graph()
  685. # only graph_opt_level takes effect in dump
  686. self._apply_graph_options(graph)
  687. for i, h in enumerate(self._arg_bindings):
  688. info = self._tinfo[h]
  689. h2v[h] = graph.make_h2d(
  690. dtype=info.dtype,
  691. device=dumped_device,
  692. shape=info.shape or (1,),
  693. name=arg_names[i] if arg_names else None,
  694. )
  695. for k, h in self._kwarg_bindings.items():
  696. info = self._tinfo[h]
  697. h2v[h] = graph.make_h2d(
  698. dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
  699. )
  700. for op, ihandles, ohandles in self._seq:
  701. if isinstance(op, str) and op == "Const":
  702. assert len(ihandles) == 0
  703. (h,) = ohandles
  704. info = self._tinfo[h]
  705. if h not in h2v:
  706. assert info.external
  707. assert info.bound_data
  708. h2v[h] = graph.make_const(
  709. info.bound_data.numpy(), dtype=info.dtype, device=info.device,
  710. )
  711. continue
  712. ivars = []
  713. for h in ihandles:
  714. info = self._tinfo[h]
  715. if h not in h2v:
  716. assert info.external
  717. assert info.bound_data
  718. h2v[h] = graph.make_const(
  719. info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
  720. )
  721. ivars.append(h2v[h])
  722. ovars = G.apply_normal_varnode(op, *ivars)
  723. assert len(ovars) == len(ohandles)
  724. h2v.update(zip(ohandles, ovars))
  725. dest_vars = []
  726. for i, h in enumerate(self._output_bindings):
  727. v = h2v[h]
  728. if output_names:
  729. v.name = output_names[i]
  730. dest_vars.append(v)
  731. if optimize_for_inference:
  732. dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
  733. if isinstance(file, str):
  734. permission = "wb" if append == False else "ab"
  735. file = open(file, permission)
  736. dump_content, dump_info = G.dump_graph(dest_vars)
  737. file.write(dump_content)
  738. return dump_info
    def _process_inputs(self, *args, **kwargs):
        """Bind the traced function's tensor arguments (capture_as_const mode).

        First (tracing) call: each positional tensor, then each keyword
        argument that is a tensor, gets a new handle. The handles are stored
        in ``_arg_bindings`` / ``_kwarg_bindings``, and the tensors are
        remembered in ``_inputs_to_restore`` so they can be detached later.

        Later (compiled) calls: the arguments are checked against those
        bindings, their device values are fed into the input nodes, and
        ``_tensor_remaps`` is filled so that _apply_op can map each argument
        tensor back to its handle.

        Raises:
            TypeError: if a positional argument is not a tensor, or a dtype or
                device differs from last time.
            TraceMismatchError: if the number of positional arguments, or the
                set of tensor keyword arguments, differs from last time.
        """
        if self._untraced:
            self._inputs_to_restore = []

            def record_input(x):
                # allocate a handle for an argument tensor and attach it
                if x is None:
                    return
                h, info = self._new_handle()
                info.external = False
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.numpy().shape
                x._mixin_handle = h
                x._recording = True
                x._trace_mixin_info = info
                self._inputs_to_restore.append(x)
                return h

            self._arg_bindings = []
            for i, x in enumerate(args):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                self._arg_bindings.append(record_input(x))

            # non-tensor keyword arguments are ignored
            self._kwarg_bindings = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    self._kwarg_bindings[k] = record_input(x)
        else:
            if len(args) != len(self._arg_bindings):
                raise TraceMismatchError("positional argument length mismatch")

            self._tensor_remaps = {}

            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("args[%d].dtype different from last time" % i)
                if x.device != info.device:
                    raise TypeError("args[%d].device different from last time" % i)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

            kwargs_tensors = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    kwargs_tensors[k] = x
            if set(kwargs_tensors) != set(self._kwarg_bindings):
                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                if too_many:
                    raise TraceMismatchError(
                        "keyword arguments found to be tensor this time "
                        "but were non-tensor previously: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "keyword arguments found to be non-tensor this time "
                        "but were tensor previously: %s" % " ".join(too_few)
                    )
            for k, h in self._kwarg_bindings.items():
                x = kwargs_tensors[k]
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("kwargs[%s].dtype different from last time" % k)
                if x.device != info.device:
                    raise TypeError("kwargs[%s].device different from last time" % k)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
  810. def _process_outputs(self, outputs):
  811. output_names = None
  812. if isinstance(outputs, collections.abc.Mapping):
  813. output_names, outputs = zip(*sorted(outputs.items()))
  814. elif not isinstance(outputs, collections.abc.Sequence):
  815. outputs = (outputs,)
  816. if not self._untraced:
  817. if output_names != self._output_names:
  818. too_many = set(output_names) - set(self._output_names)
  819. too_few = set(self._output_names) - set(output_names)
  820. if too_many:
  821. raise TraceMismatchError(
  822. "output has more keys than last time: %s" % " ".join(too_many)
  823. )
  824. if too_few:
  825. raise TraceMismatchError(
  826. "output has less keys than last time: %s" % " ".join(too_few)
  827. )
  828. if len(outputs) != len(self._output_bindings):
  829. raise TraceMismatchError("output size differs from last time")
  830. else:
  831. self._output_names = output_names
  832. self._output_bindings = []
  833. for i, x in enumerate(outputs):
  834. if not isinstance(x, RawTensor):
  835. raise TypeError("every item of return value should be tensor")
  836. if self._untraced:
  837. h = x._mixin_handle
  838. if h < 0:
  839. raise RuntimeError("output is not computed from inputs")
  840. self._output_bindings.append(h)
  841. else:
  842. h = x._mixin_handle
  843. if h not in self._output_handles:
  844. raise RuntimeError("output is not computed from inputs")
  845. if h != self._output_bindings[i]:
  846. raise TraceMismatchError(
  847. "retval[%s] is a different tensor than last time"
  848. % (output_names and output_names[i] or i)
  849. )
  850. def get_profile(self):
  851. """
  852. Get profiling result for compiled trace.
  853. :return: a json compatible object.
  854. """
  855. if not self._profiler:
  856. raise RuntimeError("trace is not set with profiling=True")
  857. return json.loads(self._profiler.get())
  858. def __del__(self):
  859. for x in self._tinfo:
  860. if getattr(x, "bound_data", None):
  861. x.bound_data = None
  862. def trace(self, *args, **kwargs):
  863. raise NotImplementedError(
  864. "trace is deemed unbeneficial with the new "
  865. "tracing mechanism. You should alwasy use __call__."
  866. )
  867. class CompiledTensorProxy:
  868. """
  869. Duck-typed RawTensor
  870. """
  871. def __init__(self, handle):
  872. self.__handle = handle
  873. self._isscalar = False
  874. self.__info = active_trace._tinfo[handle]
  875. self.__shape = None
  876. self.__data = None
  877. self.__value = None
  878. @property
  879. def dtype(self):
  880. return self.__info.varnode.dtype
  881. @property
  882. def device(self):
  883. return self.__info.varnode.device
  884. @property
  885. def shape(self):
  886. if self._isscalar:
  887. return ()
  888. if self.__shape is None:
  889. if self.__info.shape_read:
  890. self.__shape = self.__info.shape_reader.get_value().shape
  891. elif self.__info.data_read:
  892. self.__shape = self._dev_tensor().shape
  893. else:
  894. # c++ will throw TraceReadError
  895. return None
  896. return self.__shape
  897. def numpy(self):
  898. if self.__value is None:
  899. if self.__info.value_read:
  900. self.__value = self.__info.value_reader.get_value()
  901. elif self.__info.data_read:
  902. self.__value = self._dev_tensor().numpy()
  903. else:
  904. # c++ will throw TraceReadError
  905. return None
  906. # c++ side will handle scalar case
  907. return self.__value
  908. def _dev_tensor(self):
  909. if self.__data is None:
  910. if not self.__info.data_read:
  911. # c++ will throw TraceReadError
  912. return None
  913. self.__data = self.__info.data_reader.get_value()
  914. return self.__data
  915. def __del__(self):
  916. if self.__info.shape_read and self.__shape is not None:
  917. self.__info.shape_reader.drop_value()
  918. if self.__info.value_read and self.__value is not None:
  919. self.__info.value_reader.drop_value()
  920. if self.__info.data_read and self.__data is not None:
  921. self.__info.data_reader.drop_value()
  922. def assign_raw_tensor(lhs, rhs):
  923. lhs.__init__(rhs)
  924. def apply_symbolic_mode(op: OpDef, *args: RawTensor):
  925. graph = active_trace._lazy_eval_graph
  926. ivars = []
  927. for x in args:
  928. var = getattr(x, "_varnode", None)
  929. if var:
  930. ivars.append(var)
  931. else:
  932. data_setter = G.InputNode(
  933. device=x.device,
  934. dtype=x.dtype,
  935. shape=x.numpy().shape or (1,),
  936. graph=graph,
  937. use_static_shape=True,
  938. )
  939. var = data_setter.outputs[0]
  940. ivars.append(var)
  941. data_setter.set_value(x._dev_tensor())
  942. require_links = type(op) in _io_op_types
  943. if require_links and active_trace._lazy_eval_links:
  944. assert len(ivars) > 0, "op should has at least one input"
  945. opnode = G.VirtualDepNode(
  946. [ivars[0], *active_trace._lazy_eval_links],
  947. str(active_trace._lazy_eval_links[0].device),
  948. )
  949. ivars[0] = opnode.outputs[0]
  950. active_trace._lazy_eval_links = (ivars[0],)
  951. if isinstance(op, BackwardGraph):
  952. ovars = G.apply_backward_varnode(op, *ivars)
  953. else:
  954. ovars = G.apply_normal_varnode(op, *ivars)
  955. outputs = [RawTensor(o) for o in ovars]
  956. if require_links:
  957. active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)
  958. return outputs
  959. def apply_const_symbolic_mode(value, dtype, device):
  960. graph = active_trace._lazy_eval_graph
  961. # don't need to unset tracing
  962. # because varnode construction will ignore tracing flag
  963. ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
  964. if np.array(value).ndim == 0:
  965. setscalar(ret)
  966. return (ret,)
  967. def apply_compiled_mode(op: OpDef, *args: RawTensor):
  968. if skip_tracing:
  969. args = [
  970. RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
  971. for x in args
  972. ]
  973. unset_tracing()
  974. ret = apply(op, *args)
  975. set_tracing()
  976. return ret
  977. return active_trace._apply_op(op, args)
  978. def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
  979. if skip_tracing:
  980. args = [
  981. RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
  982. for x in args
  983. ]
  984. unset_tracing()
  985. ret = RawTensor(value, dtype, device, False)
  986. set_tracing()
  987. return ret
  988. return active_trace._apply_const(value, dtype, device)
  989. def apply_with_tracing(op: OpDef, *args: RawTensor):
  990. if active_trace._symbolic:
  991. outputs = apply_symbolic_mode(op, *args)
  992. else:
  993. unset_tracing()
  994. outputs = apply(op, *args)
  995. set_tracing()
  996. active_trace._record_op(op, args, outputs)
  997. return list(outputs)
  998. def apply_const_with_tracing(value, dtype, device, is_const, no_cache):
  999. if active_trace._symbolic:
  1000. outputs = apply_const_symbolic_mode(value, dtype, device)
  1001. else:
  1002. unset_tracing()
  1003. outputs = (RawTensor(value, dtype, device, False),)
  1004. set_tracing()
  1005. active_trace._record_const(outputs)
  1006. return list(outputs)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。