You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tracing.py 43 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import contextlib
  11. import functools
  12. import itertools
  13. import json
  14. import os
  15. import typing
  16. import warnings
  17. import weakref
  18. import numpy as np
  19. from ..core._imperative_rt import GraphProfiler
  20. from ..core._imperative_rt.core2 import Tensor
  21. from ..core._imperative_rt.ops import (
  22. CollectiveComm,
  23. GaussianRNG,
  24. RemoteRecv,
  25. RemoteSend,
  26. UniformRNG,
  27. )
  28. from ..core._trace_option import set_symbolic_shape
  29. from ..core._wrap import device as as_device
  30. from ..core.ops.special import Const
  31. from ..core.tensor import megbrain_graph as G
  32. from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
  33. from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
  34. from .sublinear_memory_config import SublinearMemoryConfig
  35. def _input_node_use_static_shape():
  36. return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
  37. class TraceMismatchError(RuntimeError):
  38. pass
  39. active_trace = None
  40. skip_tracing = False
  41. def is_tracing():
  42. if active_trace is None:
  43. return False
  44. else:
  45. return not skip_tracing
  46. @contextlib.contextmanager
  47. def exclude_from_trace():
  48. global skip_tracing
  49. if skip_tracing:
  50. yield
  51. return
  52. try:
  53. skip_tracing = True
  54. if active_trace is not None:
  55. active_trace._begin_excluded_region()
  56. yield
  57. finally:
  58. skip_tracing = False
  59. class TensorInfo:
  60. __slots__ = (
  61. # collected attributes
  62. "external",
  63. "exported",
  64. "data_read",
  65. "shape_read",
  66. "value_read",
  67. "device",
  68. "dtype",
  69. "shape",
  70. "is_const",
  71. "bound_data",
  72. # resources for execution
  73. "varnode",
  74. "data_setter",
  75. "shape_reader",
  76. "value_reader",
  77. "data_reader",
  78. )
  79. def __init__(self):
  80. self.exported = None
  81. self.data_read = None
  82. self.shape_read = None
  83. self.value_read = None
  84. self.bound_data = None
  85. self.data_setter = None
  86. self.shape_reader = None
  87. self.value_reader = None
  88. self.data_reader = None
  89. _io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
  90. class trace:
  91. """
  92. Wraps a callable and provide:
  93. * tracing via :meth:`.trace` and :meth:`.dump`
  94. * accelerated evalutaion via :meth:`.__call__`
  95. :param function: the function will be traced.
  96. :param symbolic: whether to apply symbolic execution for tracing. Default: False
  97. :param capture_as_const: capture global vars or closures as const value. Default: False
  98. :param sublinear_memory_config: configuration for sublinear memory optimization.
  99. If not None, it enables sublinear memory optimization with given setting.
  100. :param profiling: whether to profile compiled trace. Default: False
  101. :param opt_level: optimization level for compiling trace.
  102. :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
  103. """
  104. def __new__(cls, *args, **kwargs):
  105. if not args:
  106. return functools.partial(cls, **kwargs)
  107. return super().__new__(cls)
  108. def __init__(
  109. self,
  110. function,
  111. symbolic=False,
  112. capture_as_const=False,
  113. sublinear_memory_config: SublinearMemoryConfig = None,
  114. profiling: bool = False,
  115. opt_level: int = None,
  116. symbolic_shape: bool = True,
  117. ):
  118. self.__wrapped__ = function
  119. self._symbolic = symbolic
  120. self._capture_as_const = capture_as_const
  121. self._sublinear_memory_config = sublinear_memory_config
  122. self._profiling = profiling
  123. self._profiler = None
  124. self._graph_opt_level = opt_level
  125. self._symbolic_shape = symbolic_shape
  126. self._reset()
  127. def _reset(self):
  128. self._untraced = True
  129. self._tinfo = [] # handle -> TensorInfo
  130. self._seq = []
  131. self._pc = 0
  132. self._graph = None
  133. self._need_reset_nodes = None
  134. self._lazy_eval_graph = None
  135. self._lazy_eval_tensors = weakref.WeakSet()
  136. self._lazy_eval_links = None
  137. self._active_tensors = weakref.WeakSet()
  138. self._tensor_remaps = None
  139. self._inputs_to_restore = None
  140. self._arg_bindings = None
  141. self._kwarg_bindings = None
  142. self._output_bindings = None
  143. self._output_names = None
  144. def _new_handle(self):
  145. handle = len(self._tinfo)
  146. info = TensorInfo()
  147. self._tinfo.append(info)
  148. return handle, info
  149. def _apply_op(self, op, args):
  150. assert not self._untraced
  151. # check against trace
  152. if self._pc >= len(self._seq):
  153. raise TraceMismatchError("trace should end here, but more op observed")
  154. record = self._seq[self._pc]
  155. op_, ihandles, ohandles = record
  156. if op != op_:
  157. raise TraceMismatchError("op different from last time")
  158. if len(ihandles) != len(args):
  159. raise TraceMismatchError("op input size different from last time")
  160. for h, x in zip(ihandles, args):
  161. info = self._tinfo[h]
  162. if info.external:
  163. if (
  164. x.__class__ is CompiledTensorProxy
  165. and not self._tinfo[x._CompiledTensorProxy__handle].exported
  166. ):
  167. raise TraceMismatchError(
  168. "failed to capture: input was an external tensor "
  169. "last time, got an internal tensor this time"
  170. )
  171. if info.bound_data:
  172. if x.__class__ is CompiledTensorProxy:
  173. raise TraceMismatchError(
  174. "const capture violated: was an external tensor "
  175. "last time, got an internal tensor this time"
  176. )
  177. if x._handle != info.bound_data._handle:
  178. if not np.array_equal(x.numpy(), info.bound_data.numpy()):
  179. raise TraceMismatchError(
  180. "const capture violated: got "
  181. "a different tensor this time"
  182. )
  183. else:
  184. if info.dtype != x.dtype:
  185. raise TraceMismatchError(
  186. "failed to capture: different dtype from last time"
  187. )
  188. if info.device != x.device:
  189. raise TraceMismatchError(
  190. "failed to capture: different device from last time"
  191. )
  192. info.data_setter.set_value(x._dev_tensor())
  193. else:
  194. if x.__class__ is not CompiledTensorProxy:
  195. if x not in self._tensor_remaps:
  196. raise TraceMismatchError(
  197. "unexpected capture: trying to use an external tensor as "
  198. "input, but that input was an internal tensor last time"
  199. )
  200. else:
  201. x = self._tensor_remaps[x]
  202. if x._CompiledTensorProxy__handle != h:
  203. raise TraceMismatchError(
  204. "mis-wiring: input edge to an data flow "
  205. "graph node is different from last time"
  206. )
  207. self._pc += 1
  208. outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
  209. self._active_tensors.update(outputs)
  210. return outputs
  211. def _apply_const(self, op, args):
  212. assert not self._untraced
  213. # check against trace
  214. if self._pc >= len(self._seq):
  215. raise TraceMismatchError("trace should end here, but more op observed")
  216. record = self._seq[self._pc]
  217. op_, ihandles, ohandles = record
  218. assert isinstance(op_, Const)
  219. eq = op_.value == op.value
  220. if not isinstance(eq, bool):
  221. eq = all(eq)
  222. if not eq:
  223. raise TraceMismatchError(
  224. "const tensor violated: got a different tensor this time"
  225. )
  226. self._pc += 1
  227. (h,) = ohandles
  228. outputs = tuple([self._tinfo[h].bound_data])
  229. return outputs
  230. def _record_op(self, op, inputs, outputs):
  231. if skip_tracing:
  232. for x in inputs:
  233. h = getattr(x, "_TraceMixin__handle", None)
  234. if h is not None:
  235. self._tinfo[h].data_read = True
  236. return
  237. ihandles = []
  238. for x in inputs:
  239. h = getattr(x, "_TraceMixin__handle", None)
  240. if h is None or (not self._capture_as_const and self._tinfo[h].exported):
  241. h, info = self._new_handle()
  242. info.external = True
  243. info.device = x.device
  244. info.dtype = x.dtype
  245. info.shape = x.shape
  246. if self._capture_as_const:
  247. info.bound_data = x
  248. ihandles.append(h)
  249. ohandles = []
  250. for x in outputs:
  251. h, info = self._new_handle()
  252. ohandles.append(h)
  253. info.external = False
  254. TraceMixin._TraceMixin__inject(x, h)
  255. self._seq.append((op, tuple(ihandles), tuple(ohandles)))
  256. self._active_tensors.update(outputs)
  257. def _record_const(self, op, outputs):
  258. if skip_tracing:
  259. (x,) = outputs
  260. h = getattr(x, "_TraceMixin__handle", None)
  261. if h is not None:
  262. self._tinfo[h].data_read = True
  263. return
  264. (x,) = outputs
  265. h, info = self._new_handle()
  266. ohandles = [h]
  267. info.external = True
  268. info.device = x.device
  269. info.dtype = x.dtype
  270. info.shape = x.shape
  271. info.bound_data = x
  272. info.is_const = True
  273. TraceMixin._TraceMixin__inject(x, h)
  274. self._seq.append((op, tuple(), tuple(ohandles)))
  275. def _set_active(self, active: bool):
  276. global active_trace
  277. if active:
  278. if active_trace:
  279. raise NotImplementedError("sorry, not implemented: nested trace")
  280. active_trace = self
  281. else:
  282. assert active_trace is self
  283. active_trace = None
  284. def _init_trace(self, symbolic: bool):
  285. apply.enable(apply_with_tracing)
  286. apply.enable(apply_const_with_tracing)
  287. if symbolic:
  288. apply.enable(apply_symbolic_mode)
  289. apply.enable(apply_const_symbolic_mode)
  290. self._lazy_eval_graph = G.Graph()
  291. self._apply_graph_options(self._lazy_eval_graph)
  292. self._lazy_eval_links = ()
  293. def _take_escaped_tensors(self):
  294. escaped_tensors = tuple(self._active_tensors)
  295. self._active_tensors.clear()
  296. return escaped_tensors
  297. def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
  298. readers = [
  299. G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
  300. for x in lazy_eval_tensors
  301. ]
  302. self._apply_graph_options(lazy_eval_graph)
  303. # FIXME
  304. if self._graph_opt_level is not None:
  305. lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
  306. else:
  307. lazy_eval_graph.options.graph_opt_level = 2
  308. lazy_eval_graph.set_priority_to_id([*lazy_eval_links, *readers])
  309. lazy_eval_graph.compile(*lazy_eval_links, *readers)
  310. lazy_eval_graph()
  311. for r, x in zip(readers, lazy_eval_tensors):
  312. assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
  313. @contextlib.contextmanager
  314. def _setup(self):
  315. interrupted = False
  316. def do_enter():
  317. self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
  318. self._set_active(True)
  319. if self._untraced:
  320. self._init_trace(self._symbolic)
  321. else:
  322. apply.enable(apply_compiled_mode)
  323. apply.enable(apply_const_compiled_mode)
  324. if self._graph is None:
  325. self._compile()
  326. self._graph.execute()
  327. def do_finalize():
  328. escaped_tensors = self._take_escaped_tensors()
  329. if self._untraced:
  330. for x in escaped_tensors:
  331. info = self._tinfo[x._TraceMixin__handle]
  332. info.data_read = True
  333. x._TraceMixin__restore()
  334. if self._inputs_to_restore:
  335. for x in self._inputs_to_restore:
  336. x._TraceMixin__restore()
  337. if self._symbolic and (
  338. self._lazy_eval_tensors or self._lazy_eval_links
  339. ):
  340. # eval lazy eval tensors
  341. self._lazy_eval(
  342. self._lazy_eval_graph,
  343. tuple(self._lazy_eval_tensors),
  344. self._lazy_eval_links,
  345. )
  346. self._lazy_eval_graph = None
  347. self._lazy_eval_tensors = None
  348. self._lazy_eval_links = None
  349. self._untraced = False
  350. else:
  351. # compiled_tensor leaks
  352. if self._pc == len(self._seq):
  353. for x in escaped_tensors:
  354. try:
  355. assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
  356. except TraceMismatchError:
  357. # TraceMismatchError thrown in do_exit
  358. pass
  359. self._graph.wait()
  360. self._reset_exec_env()
  361. # reset status
  362. self._pc = 0
  363. self._tensor_remaps = None
  364. apply.disable(apply_with_tracing)
  365. apply.disable(apply_const_with_tracing)
  366. apply.disable(apply_symbolic_mode)
  367. apply.disable(apply_const_symbolic_mode)
  368. apply.disable(apply_compiled_mode)
  369. apply.disable(apply_const_compiled_mode)
  370. self._set_active(False)
  371. # Restore global variable
  372. set_symbolic_shape(self._save_symbolic_shape)
  373. def do_exit():
  374. if not self._untraced and self._pc != len(self._seq):
  375. raise TraceMismatchError("premature end")
  376. if not self._symbolic or not self._untraced:
  377. for x in self._active_tensors:
  378. x._dev_tensor()
  379. try:
  380. do_enter()
  381. yield
  382. do_exit()
  383. except:
  384. interrupted = True
  385. raise
  386. finally:
  387. do_finalize()
  388. if interrupted:
  389. self._reset()
  390. def _begin_excluded_region(self):
  391. if self._capture_as_const:
  392. raise RuntimeError(
  393. "exclude_from_trace cannot be used with capture_as_const"
  394. )
  395. if self._untraced:
  396. # conditionally reading a compiled tensor in excluded region
  397. # is permitted, so we have to assume every tensor might be read
  398. for x in self._active_tensors:
  399. info = self._tinfo[x._TraceMixin__handle]
  400. info.exported = True
  401. info.data_read = True
  402. def _apply_graph_options(self, graph):
  403. graph.options.no_force_inplace = True
  404. graph.options.seq_opt.enable_seq_comp_node_opt = False
  405. # graph opt level
  406. # if self._graph_opt_level is not None:
  407. # graph.options.graph_opt_level = self._graph_opt_level
  408. # FIXME
  409. graph.options.graph_opt_level = 0
  410. # sublinear
  411. if self._sublinear_memory_config is not None:
  412. graph.options.enable_sublinear_memory_opt = True
  413. sublinear_config = graph.options.sublinear_mem_config
  414. sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
  415. sublinear_config.genetic_nr_iter = (
  416. self._sublinear_memory_config.genetic_nr_iter
  417. )
  418. sublinear_config.genetic_pool_size = (
  419. self._sublinear_memory_config.genetic_pool_size
  420. )
  421. sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
  422. sublinear_config.num_worker = self._sublinear_memory_config.num_worker
  423. # profile
  424. if self._profiling:
  425. self._profiler = GraphProfiler(graph)
  426. def _compile(self):
  427. graph = self._graph = G.Graph()
  428. graph.options.async_exec_level = 0b100
  429. self._apply_graph_options(graph)
  430. # graph.options.graph_opt_level = 0
  431. need_reset_nodes = self._need_reset_nodes = []
  432. # links enforce ordering of I/O nodes
  433. in_out_links = ()
  434. io_links = ()
  435. readers = []
  436. if self._capture_as_const:
  437. for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
  438. info = self._tinfo[h]
  439. opnode = info.data_setter = G.InputNode(
  440. device=info.device,
  441. dtype=info.dtype,
  442. shape=info.shape or (1,),
  443. graph=graph,
  444. use_static_shape=_input_node_use_static_shape(),
  445. )
  446. need_reset_nodes.append(opnode)
  447. info.varnode = opnode.outputs[0]
  448. in_out_links += opnode.outputs[1:]
  449. for op, ihandles, ohandles in self._seq:
  450. if isinstance(op, Const):
  451. assert len(ihandles) == 0
  452. (h,) = ohandles
  453. info = self._tinfo[h]
  454. if not hasattr(info, "varnode"):
  455. assert info.external
  456. assert info.bound_data
  457. info.varnode = graph.make_const(
  458. info.bound_data.numpy(),
  459. info.bound_data.dtype,
  460. info.bound_data.device,
  461. )
  462. continue
  463. require_links = type(op) in _io_op_types
  464. ivars = []
  465. for i, h in enumerate(ihandles):
  466. info = self._tinfo[h]
  467. if not hasattr(info, "varnode"):
  468. assert info.external
  469. if info.bound_data:
  470. if hasattr(info, "is_const") and info.is_const:
  471. info.varnode = graph.make_const(
  472. info.bound_data.numpy(),
  473. info.bound_data.dtype,
  474. info.bound_data.device,
  475. )
  476. else:
  477. info.varnode = graph.make_const(
  478. info.bound_data._dev_tensor()
  479. # info.bound_data.numpy()
  480. )
  481. else:
  482. opnode = info.data_setter = G.InputNode(
  483. *in_out_links,
  484. device=info.device,
  485. dtype=info.dtype,
  486. shape=info.shape or (1,),
  487. graph=graph,
  488. use_static_shape=_input_node_use_static_shape(),
  489. )
  490. need_reset_nodes.append(opnode)
  491. info.varnode, *in_out_links = opnode.outputs
  492. if require_links and i == 0 and len(io_links) > 0:
  493. opnode = G.VirtualDepNode(
  494. [info.varnode, *io_links], str(io_links[0].device)
  495. )
  496. info.varnode = opnode.outputs[0]
  497. io_links = (info.varnode,)
  498. ivars.append(info.varnode)
  499. ovars = apply(op, *ivars)
  500. if require_links and len(ovars) > 0:
  501. io_links = (ovars[0],)
  502. assert len(ovars) == len(ohandles)
  503. for h, v in zip(ohandles, ovars):
  504. info = self._tinfo[h]
  505. info.varnode = v
  506. def add_reader(opnode):
  507. nonlocal in_out_links
  508. need_reset_nodes.append(opnode)
  509. readers.append(opnode.outputs[0])
  510. in_out_links = opnode.outputs
  511. if info.data_read:
  512. # Shape can be obtained from data so doesn't need its own
  513. # output node. On the other hand, value is read separately
  514. # to leverage eager h2d copy
  515. info.shape_read = False
  516. opnode = info.data_reader = G.OutputNode(v, *in_out_links)
  517. add_reader(opnode)
  518. if info.value_read:
  519. opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
  520. add_reader(opnode)
  521. if info.shape_read:
  522. opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
  523. add_reader(opnode)
  524. # FIXME
  525. if self._graph_opt_level is not None:
  526. graph.options.graph_opt_level = self._graph_opt_level
  527. else:
  528. graph.options.graph_opt_level = 2
  529. graph.set_priority_to_id([*readers, *in_out_links, *io_links])
  530. graph.compile(*readers, *in_out_links, *io_links)
  531. def _reset_exec_env(self):
  532. for opnode in self._need_reset_nodes:
  533. opnode.reset()
  534. def _require_shape(self, handle):
  535. info = self._tinfo[handle]
  536. info.shape_read = True
  537. def _require_value(self, handle):
  538. info = self._tinfo[handle]
  539. info.value_read = True
  540. def _require_data(self, handle):
  541. info = self._tinfo[handle]
  542. info.data_read = True
  543. def __call__(self, *args, **kwargs):
  544. if is_tracing():
  545. return self.__wrapped__(*args, **kwargs)
  546. with self._setup():
  547. if self._capture_as_const:
  548. self._process_inputs(*args, **kwargs)
  549. outputs = self.__wrapped__(*args, **kwargs)
  550. if self._capture_as_const:
  551. self._process_outputs(outputs)
  552. return outputs
  553. def dump(
  554. self,
  555. file,
  556. *,
  557. arg_names=None,
  558. output_names=None,
  559. append=False,
  560. optimize_for_inference=True,
  561. **kwargs
  562. ):
  563. r"""
  564. Serializes trace to file system.
  565. :param file: output file, could be file object or filename.
  566. :param arg_names: names of the input tensors in the traced function.
  567. :param output_names: names of the output tensors in the traced function,
  568. use the default name if not specified.
  569. :param append: whether output is appended to ``file``.
  570. Only works when ``file`` is str.
  571. :param optimize_for_inference: enbale optmizations,
  572. will skip all optimize options if this is False. Default: True
  573. :Keyword Arguments:
  574. * enable_io16xc32 --
  575. whether to use float16 for I/O between oprs and use
  576. float32 as internal computation precision. Note the output var would be
  577. changed to float16.
  578. * enable_ioc16 --
  579. whether to use float16 for both I/O and computation
  580. precision.
  581. * enable_hwcd4 --
  582. whether to use NHWCD4 data layout. This is faster on some
  583. OpenCL backend.
  584. * enable_nchw88 --
  585. whether to use NCHW88 data layout, currently
  586. used in X86 AVX backend.
  587. * enable_nchw44 --
  588. whether to use NCHW44 data layout, currently
  589. used in arm backend.
  590. * enable_nchw44_dot --
  591. whether to use NCHW44_dot data layout, currently
  592. used in armv8.2+dotprod backend.
  593. * enable_nchw4 --
  594. whether to use NCHW4 data layout, currently
  595. used in nvidia backend(based on cudnn).
  596. * enable_nchw32 --
  597. whether to use NCHW32 data layout, currently
  598. used in nvidia backend with tensorcore(based on cudnn).
  599. * enable_chwn4 --
  600. whether to use CHWN4 data layout, currently
  601. used in nvidia backend with tensorcore.
  602. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
  603. into one opr.
  604. * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
  605. input for inference on nvidia backend(this optimization pass will
  606. result in mismatch of the precision of output of training and
  607. inference)
  608. """
  609. if not self._capture_as_const:
  610. raise ValueError(
  611. "you must specify capture_as_const=True at __init__ to use dump"
  612. )
  613. if self._untraced:
  614. raise RuntimeError("should run at least once before calling dump")
  615. if self._output_names and output_names:
  616. raise TypeError(
  617. "cannot specify output_names when output is already in dict format"
  618. )
  619. if output_names and not isinstance(output_names, collections.abc.Sequence):
  620. output_names = (output_names,)
  621. if output_names and len(output_names) != len(self._output_bindings):
  622. raise ValueError(
  623. "wrong number of output_names, should be {} values".format(
  624. len(self._output_bindings)
  625. )
  626. )
  627. if arg_names is None:
  628. arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
  629. if arg_names and not isinstance(arg_names, collections.abc.Sequence):
  630. arg_names = (arg_names,)
  631. if arg_names and len(arg_names) != len(self._arg_bindings):
  632. raise ValueError(
  633. "wrong number of arg_names, should be {} values".format(
  634. len(self._arg_bindings)
  635. )
  636. )
  637. output_names = output_names or self._output_names
  638. dumped_device = as_device("xpux")
  639. h2v = {}
  640. graph = G.Graph()
  641. # only graph_opt_level takes effect in dump
  642. self._apply_graph_options(graph)
  643. for i, h in enumerate(self._arg_bindings):
  644. info = self._tinfo[h]
  645. h2v[h] = graph.make_h2d(
  646. dtype=info.dtype,
  647. device=dumped_device,
  648. shape=info.shape or (1,),
  649. name=arg_names[i] if arg_names else None,
  650. )
  651. for k, h in self._kwarg_bindings.items():
  652. info = self._tinfo[h]
  653. h2v[h] = graph.make_h2d(
  654. dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
  655. )
  656. for op, ihandles, ohandles in self._seq:
  657. if isinstance(op, Const):
  658. assert len(ihandles) == 0
  659. (h,) = ohandles
  660. info = self._tinfo[h]
  661. if h not in h2v:
  662. assert info.external
  663. assert info.bound_data
  664. h2v[h] = graph.make_const(
  665. info.bound_data.numpy(), dtype=info.dtype, device=info.device,
  666. )
  667. continue
  668. ivars = []
  669. for h in ihandles:
  670. info = self._tinfo[h]
  671. if h not in h2v:
  672. assert info.external
  673. assert info.bound_data
  674. h2v[h] = graph.make_const(
  675. info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
  676. )
  677. ivars.append(h2v[h])
  678. ovars = apply(op, *ivars)
  679. assert len(ovars) == len(ohandles)
  680. h2v.update(zip(ohandles, ovars))
  681. dest_vars = []
  682. for i, h in enumerate(self._output_bindings):
  683. v = h2v[h]
  684. if output_names:
  685. v.name = output_names[i]
  686. dest_vars.append(v)
  687. if optimize_for_inference:
  688. dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
  689. if isinstance(file, str):
  690. permission = "wb" if append == False else "ab"
  691. file = open(file, permission)
  692. dump_content, dump_info = G.dump_graph(dest_vars)
  693. file.write(dump_content)
  694. return dump_info
  695. def _process_inputs(self, *args, **kwargs):
  696. if self._untraced:
  697. self._inputs_to_restore = []
  698. def record_input(x):
  699. if x is None:
  700. return
  701. h, info = self._new_handle()
  702. info.external = False
  703. info.device = x.device
  704. info.dtype = x.dtype
  705. info.shape = x.shape
  706. TraceMixin._TraceMixin__inject(x, h)
  707. self._inputs_to_restore.append(x)
  708. return h
  709. self._arg_bindings = []
  710. for i, x in enumerate(args):
  711. x = find_raw_tensor(x)
  712. if x is None:
  713. raise TypeError(
  714. "positional arguments should all be tensor "
  715. "but args[%d] cannot be recognized as one" % i
  716. )
  717. self._arg_bindings.append(record_input(x))
  718. self._kwarg_bindings = {}
  719. for k, x in kwargs.items():
  720. x = find_raw_tensor(x)
  721. if x is not None:
  722. self._kwarg_bindings[k] = record_input(x)
  723. else:
  724. if len(args) != len(self._arg_bindings):
  725. raise TraceMismatchError("positional argument length mismatch")
  726. self._tensor_remaps = {}
  727. for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
  728. x = find_raw_tensor(x)
  729. if x is None:
  730. raise TypeError(
  731. "positional arguments should all be tensor "
  732. "but args[%d] cannot be recognized as one" % i
  733. )
  734. info = self._tinfo[h]
  735. if x.dtype != info.dtype:
  736. raise TypeError("args[%d].dtype different from last time" % i)
  737. if x.device != info.device:
  738. raise TypeError("args[%d].device different from last time" % i)
  739. info.data_setter.set_value(x._dev_tensor())
  740. self._tensor_remaps[x] = CompiledTensorProxy(h)
  741. kwargs_tensors = {}
  742. for k, x in kwargs.items():
  743. x = find_raw_tensor(x)
  744. if x is not None:
  745. kwargs_tensors[k] = x
  746. if set(kwargs_tensors) != set(self._kwarg_bindings):
  747. too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
  748. too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
  749. if too_many:
  750. raise TraceMismatchError(
  751. "keyword arguments found to be tensor this time "
  752. "but were non-tensor previously: %s" % " ".join(too_many)
  753. )
  754. if too_few:
  755. raise TraceMismatchError(
  756. "keyword arguments found to be non-tensor this time "
  757. "but were tensor previously: %s" % " ".join(too_few)
  758. )
  759. for k, h in self._kwarg_bindings.items():
  760. x = kwargs_tensors[k]
  761. info = self._tinfo[h]
  762. if x.dtype != info.dtype:
  763. raise TypeError("kwargs[%s].dtype different from last time" % k)
  764. if x.device != info.device:
  765. raise TypeError("kwargs[%s].device different from last time" % k)
  766. info.data_setter.set_value(x._dev_tensor())
  767. self._tensor_remaps[x] = CompiledTensorProxy(h)
  def _process_outputs(self, outputs):
      """
      Validate the traced function's return value and bind its tensors.

      On the untraced (recording) run, the output names (for a mapping return
      value) and the trace handle of every output tensor are stored in
      ``self._output_names`` / ``self._output_bindings``. On later runs the
      return value must have the same structure, keys, size and tensor
      identities as recorded.

      :param outputs: a single tensor, a sequence of tensors, or a mapping of
          names to tensors (mapping items are processed in sorted key order).
      :raises TraceMismatchError: if the outputs differ from last time.
      :raises TypeError: if an output item cannot be recognized as a tensor.
      :raises RuntimeError: if an output is not computed from the inputs.
      """
      output_names = None
      if isinstance(outputs, collections.abc.Mapping):
          mapping = outputs
          # Sorting keys directly (instead of zip(*sorted(items))) also works
          # for an empty mapping, where zip(*[]) cannot be unpacked.
          output_names = tuple(sorted(mapping))
          outputs = tuple(mapping[k] for k in output_names)
      elif not isinstance(outputs, collections.abc.Sequence):
          outputs = (outputs,)

      if not self._untraced:
          if output_names != self._output_names:
              if output_names is None or self._output_names is None:
                  # Mapping vs. non-mapping return value: set(None) would
                  # raise a bare TypeError, so report the mismatch instead.
                  raise TraceMismatchError(
                      "output is %sa mapping this time but was %sa mapping "
                      "last time"
                      % (
                          "" if output_names is not None else "not ",
                          "" if self._output_names is not None else "not ",
                      )
                  )
              too_many = set(output_names) - set(self._output_names)
              too_few = set(self._output_names) - set(output_names)
              if too_many:
                  raise TraceMismatchError(
                      "output has more keys than last time: %s"
                      % " ".join(sorted(map(str, too_many)))
                  )
              if too_few:
                  raise TraceMismatchError(
                      "output has less keys than last time: %s"
                      % " ".join(sorted(map(str, too_few)))
                  )
          if len(outputs) != len(self._output_bindings):
              raise TraceMismatchError("output size differs from last time")
      else:
          self._output_names = output_names
          self._output_bindings = []

      for i, x in enumerate(outputs):
          x = find_raw_tensor(x)
          if x is None:
              raise TypeError("every item of return value should be tensor")
          if self._untraced:
              if not isinstance(x, TraceMixin):
                  raise RuntimeError("output is not computed from inputs")
              self._output_bindings.append(x._TraceMixin__handle)
          else:
              if not isinstance(x, CompiledTensorProxy):
                  raise RuntimeError("output is not computed from inputs")
              if x._CompiledTensorProxy__handle != self._output_bindings[i]:
                  # Use the key even when it is falsy (e.g. "" or 0); wrap it
                  # in a tuple so tuple-valued keys format correctly.
                  name = output_names[i] if output_names is not None else i
                  raise TraceMismatchError(
                      "retval[%s] is a different tensor than last time" % (name,)
                  )
  def get_profile(self):
      """
      Return the profiling result gathered by the compiled trace.

      :return: the profiler output decoded from JSON (a json compatible object).
      :raises RuntimeError: if no profiler is attached, i.e. the trace was not
          set with ``profiling=True``.
      """
      profiler = self._profiler
      if profiler:
          return json.loads(profiler.get())
      raise RuntimeError("trace is not set with profiling=True")
  def trace(self, *args, **kwargs):
      """
      Removed entry point kept only to give a helpful error.

      :raises NotImplementedError: always; call the trace object directly
          (``__call__``) instead.
      """
      raise NotImplementedError(
          "trace is deemed unbeneficial with the new "
          "tracing mechanism. You should always use __call__."
      )
  class CompiledTensorProxy(RawTensor):
      """
      Duck-typed RawTensor

      Stands in for a tensor produced while running a compiled trace. It keeps
      only a trace handle. Shape, value and raw device data are fetched lazily
      from the readers in ``active_trace._tinfo[handle]`` and cached on the
      instance. Each is available only if the trace recorded that it was read
      (``shape_read`` / ``value_read`` / ``data_read``); otherwise
      ``TraceMismatchError`` is raised.
      """

      def __init__(self, handle):
          # RawTensor.__init__ is not called: the proxy stores only the trace
          # handle and the per-handle info looked up from the active trace.
          self.__handle = handle
          self._isscalar = False
          self.__info = active_trace._tinfo[handle]
          self.__shape = None
          self.__data = None
          self.__value = None

      @property
      def dtype(self):
          return self.__info.varnode.dtype

      @property
      def device(self):
          return self.__info.varnode.device

      @property
      def shape(self):
          """Shape, from the shape reader if it was read, else from raw data."""
          if self._isscalar:
              return ()
          if self.__shape is None:
              if self.__info.shape_read:
                  self.__shape = self.__info.shape_reader.get_value().shape
              elif self.__info.data_read:
                  self.__shape = self._dev_tensor().shape
              else:
                  raise TraceMismatchError("shape of this tensor is not read in trace")
          return self.__shape

      def numpy(self):
          """Value, from the value reader if it was read, else from raw data."""
          if self.__value is None:
              if self.__info.value_read:
                  self.__value = self.__info.value_reader.get_value()
              elif self.__info.data_read:
                  self.__value = self._dev_tensor().numpy()
              else:
                  raise TraceMismatchError("value of this tensor is not read in trace")
          if self._isscalar:
              self.__value = self.__value.squeeze()
          return self.__value

      def _dev_tensor(self):
          """Raw device data; only available if the trace read it."""
          if self.__data is None:
              if not self.__info.data_read:
                  raise TraceMismatchError("raw data of this tensor is not read in trace")
              self.__data = self.__info.data_reader.get_value()
          return self.__data

      def _drop(self):
          return

      def _swap_in(self):
          return

      def _swap_out(self):
          return

      def __del__(self):
          """Tell each reader this proxy fetched from to drop its value."""
          try:
              info = self.__info
          except AttributeError:
              # __init__ raised before __info was assigned (e.g. unknown
              # handle), or the instance dict was cleared afterwards; there
              # is nothing this proxy can release.
              return
          if info.shape_read and self.__shape is not None:
              info.shape_reader.drop_value()
          if info.value_read and self.__value is not None:
              info.value_reader.drop_value()
          if info.data_read and self.__data is not None:
              info.data_reader.drop_value()
  class LazyEvalTensor(RawTensor):
      """
      RawTensor backed by a varnode of the lazy-evaluation graph.

      dtype, device and shape come from the varnode, and ``numpy()`` reads
      ``varnode.value``. Raw device data cannot be accessed.
      """

      def __init__(self, varnode, isscalar=False):
          super().__init__()
          # Stored name-mangled as ``_LazyEvalTensor__varnode``;
          # apply_symbolic_mode looks the attribute up by that name.
          self.__varnode = varnode
          self._isscalar = isscalar

      @property
      def dtype(self):
          return self.__varnode.dtype

      @property
      def device(self):
          return self.__varnode.device

      @property
      def shape(self):
          # Scalars always report an empty shape.
          return () if self._isscalar else self.__varnode.shape

      def numpy(self):
          value = self.__varnode.value
          return value.squeeze() if self._isscalar else value

      def _drop(self):
          """No-op."""

      def _swap_in(self):
          """No-op."""

      def _swap_out(self):
          """No-op."""

      def _dev_tensor(self):
          raise RuntimeError("cannot access data during symbolic tracing")
  class TraceMixin:
      """
      Mixin that reports tensor accesses to the active trace.

      ``__inject`` re-classes an existing tensor object, in place, into a cached
      subclass ``Traced<ClassName>(TraceMixin, <original class>)`` and attaches
      a trace handle. ``__restore`` undoes this. While injected, reading
      ``shape``, ``numpy()`` or ``_dev_tensor()`` first notifies
      ``active_trace`` (unless ``skip_tracing`` is set). It then defers to the
      original class through ``super()``.

      Note: ``__inject`` / ``__restore`` are name-mangled. Code outside this
      class must refer to them as ``_TraceMixin__inject`` /
      ``_TraceMixin__restore``.
      """

      # original class -> dynamically created "Traced<name>" subclass; shared
      # across instances so each class is specialized only once
      __subclass_cache = {}

      def __inject(self, handle):
          """
          Turn ``self`` into a traced tensor bound to ``handle``, in place.

          :param handle: trace handle stored on the instance and passed to the
              ``active_trace._require_*`` calls below.
          :return: ``self``, now an instance of the traced subclass.
          """
          # ``__class__`` is the implicit reference to TraceMixin itself (the
          # class defining this method), not to type(self)
          cache = __class__.__subclass_cache
          cls = self.__class__
          subcls = cache.get(cls)
          if subcls is None:
              # TraceMixin is listed first so its overrides win in the MRO
              subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
          self.__class__ = subcls
          self.__handle = handle
          # remembered so __restore can undo the class swap
          self.__cls = cls
          return self

      def __restore(self):
          """
          Undo ``__inject``: drop the handle and restore the original class.

          :return: ``self``.
          """
          cls = self.__cls
          del self.__handle
          del self.__cls
          self.__class__ = cls
          return self

      @property
      def shape(self):
          # report the shape access to the trace, then delegate
          if not skip_tracing:
              active_trace._require_shape(self.__handle)
          return super().shape

      def numpy(self):
          # report the value access to the trace, then delegate
          if not skip_tracing:
              active_trace._require_value(self.__handle)
          return super().numpy()

      def _dev_tensor(self):
          # report the raw-data access to the trace, then delegate
          if not skip_tracing:
              active_trace._require_data(self.__handle)
          return super()._dev_tensor()

      # The following hooks are no-ops on traced tensors (presumably
      # drop/swap memory-management hooks of RawTensor -- confirm).
      def _drop(self):
          return

      def _swap_in(self):
          return

      def _swap_out(self):
          return
  class TracedRawTensor(TraceMixin, RawTensor):
      # Statically declared traced RawTensor; TraceMixin precedes RawTensor in
      # the MRO, so its overrides take effect.
      ...
  class TracedLazyTensor(TraceMixin, LazyEvalTensor):
      # Statically declared traced LazyEvalTensor; TraceMixin precedes
      # LazyEvalTensor in the MRO, so its overrides take effect.
      ...
  def assign_raw_tensor(lhs, rhs):
      """
      Rebind ``lhs`` in place as a plain RawTensor over ``rhs``'s handle.

      The instance dicts of both ``rhs`` and ``lhs`` are cleared. ``lhs`` is
      then re-classed to RawTensor and re-initialized with ``rhs._handle``,
      keeping its own ``_isscalar`` flag.
      """
      # Capture what must survive the wipe: rhs's handle and lhs's scalar flag.
      handle, isscalar = rhs._handle, lhs._isscalar
      for tensor in (rhs, lhs):
          tensor.__dict__.clear()
      lhs.__class__ = RawTensor
      lhs.__init__(handle, isscalar=isscalar)
  # this hook turns RawTensor into LazyEvalTensor
  @apply.register()
  def apply_symbolic_mode(op: OpDef, *args: RawTensor):
      """
      Apply ``op`` symbolically on the active trace's lazy-eval graph.

      An input that is already a LazyEvalTensor contributes its varnode
      directly. Any other input is fed into the graph through a fresh
      ``G.InputNode`` whose value is set from the input's device tensor. For op
      types in ``_io_op_types``, the first input is chained after
      ``active_trace._lazy_eval_links`` via ``G.VirtualDepNode``, and the first
      output becomes the new link (presumably to keep such ops ordered).

      :return: list of LazyEvalTensor, one per output varnode; each is also
          added to ``active_trace._lazy_eval_tensors``.
      """
      graph = active_trace._lazy_eval_graph
      ivars = []
      for x in args:
          # Compare against None explicitly: a graph varnode should never be
          # judged by its own truthiness.
          var = getattr(x, "_LazyEvalTensor__varnode", None)
          if var is None:
              # Not a lazy tensor: feed its current data into the graph.
              # Scalars (shape ()) are given the static shape (1,).
              data_setter = G.InputNode(
                  device=x.device,
                  dtype=x.dtype,
                  shape=x.shape or (1,),
                  graph=graph,
                  use_static_shape=True,
              )
              var = data_setter.outputs[0]
              data_setter.set_value(x._dev_tensor())
          ivars.append(var)
      require_links = type(op) in _io_op_types
      if require_links and active_trace._lazy_eval_links:
          assert len(ivars) > 0, "op should has at least one input"
          # make the first input depend on the previous link(s)
          opnode = G.VirtualDepNode(
              [ivars[0], *active_trace._lazy_eval_links],
              str(active_trace._lazy_eval_links[0].device),
          )
          ivars[0] = opnode.outputs[0]
          active_trace._lazy_eval_links = (ivars[0],)
      ovars = apply(op, *ivars)
      if require_links:
          active_trace._lazy_eval_links = (ovars[0],)
      outputs = [LazyEvalTensor(v) for v in ovars]
      active_trace._lazy_eval_tensors.update(outputs)
      return outputs
  apply.disable(apply_symbolic_mode)
  @apply.register()
  def apply_const_symbolic_mode(op: Const, *args: RawTensor):
      """
      Materialize a ``Const`` op as a constant varnode of the lazy-eval graph.

      Returns a 1-tuple holding a LazyEvalTensor flagged as scalar, which is
      also added to ``active_trace._lazy_eval_tensors``. ``args`` is unused.
      """
      const_var = active_trace._lazy_eval_graph.make_const(
          op.value, dtype=op.dtype, device=op.device
      )
      tensor = LazyEvalTensor(const_var, isscalar=True)
      active_trace._lazy_eval_tensors.add(tensor)
      return (tensor,)
  apply.disable(apply_const_symbolic_mode)
  def _materialize_compiled_proxies(args):
      """Replace each CompiledTensorProxy in ``args`` by a raw tensor of its data."""
      return [
          as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
          for x in args
      ]


  @apply.register()
  def apply_compiled_mode(op: OpDef, *args: RawTensor):
      """
      Dispatch ``op`` to the compiled trace.

      When ``skip_tracing`` is set, compiled proxies are materialized and the
      op is passed on to the next handler (``apply.super``) instead.
      """
      if skip_tracing:
          return apply.super(op, *_materialize_compiled_proxies(args))
      return active_trace._apply_op(op, args)
  apply.disable(apply_compiled_mode)


  @apply.register()
  def apply_const_compiled_mode(op: Const, *args: RawTensor):
      """
      Dispatch a ``Const`` op to the compiled trace.

      When ``skip_tracing`` is set, compiled proxies are materialized and the
      op is passed on to the next handler (``apply.super``) instead.
      """
      if skip_tracing:
          return apply.super(op, *_materialize_compiled_proxies(args))
      return active_trace._apply_const(op, args)
  apply.disable(apply_const_compiled_mode)
  # this hook injects TraceMixin
  @apply.register()
  def apply_with_tracing(op: OpDef, *args: RawTensor):
      """Run ``op`` via the next handler, then record it on the active trace."""
      results = apply.super(op, *args)
      active_trace._record_op(op, args, results)
      return results
  apply.disable(apply_with_tracing)
  @apply.register()
  def apply_const_with_tracing(op: Const, *args: RawTensor):
      """Create the constant via the next handler, then record it on the active trace."""
      produced = apply.super(op, *args)
      active_trace._record_const(op, produced)
      return produced
  apply.disable(apply_const_with_tracing)
  class BrokenRawTensor(RawTensor):
      """
      Poisoned tensor: any attribute read or write on an instance raises
      RuntimeError("broken due to misuse of tracing").
      """

      def __setattr__(self, *_):
          raise RuntimeError("broken due to misuse of tracing")

      def __getattribute__(self, _):
          raise RuntimeError("broken due to misuse of tracing")
  @functools.singledispatch
  def find_raw_tensor(x):
      """
      Return the RawTensor underlying ``x``, or None if there is none.

      RawTensor instances are returned as-is. TensorWrapperBase instances are
      unwrapped through ``__wrapped__`` and Tensor instances through ``_data``,
      recursively. Any other object yields None.
      """
      return None


  @find_raw_tensor.register(RawTensor)
  def _(x):
      # already the raw tensor
      return x


  @find_raw_tensor.register(TensorWrapperBase)
  def _(x):
      inner = getattr(x, "__wrapped__", None)
      return None if inner is None else find_raw_tensor(inner)


  @find_raw_tensor.register(Tensor)
  def _(x):
      inner = getattr(x, "_data", None)
      return None if inner is None else find_raw_tensor(inner)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。