You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tracing.py 36 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import contextlib
  11. import functools
  12. import itertools
  13. import json
  14. import typing
  15. import warnings
  16. import weakref
  17. import numpy as np
  18. from ..core._imperative_rt import GraphProfiler
  19. from ..core._imperative_rt.ops import OprAttr
  20. from ..core._trace_option import set_tensor_shape
  21. from ..core.ops.special import Const
  22. from ..core.tensor import megbrain_graph as G
  23. from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
  24. from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
  25. from ..core.tensor.tensor import Tensor
  26. from .sublinear_memory_config import SublinearMemoryConfig
  27. class TraceMismatchError(RuntimeError):
  28. pass
# The ``trace`` instance currently running; assigned by ``trace._set_active``
# and ``None`` while no trace is active.
active_trace = None
# True while execution is inside an ``exclude_from_trace`` region.
skip_tracing = False
  31. @contextlib.contextmanager
  32. def exclude_from_trace():
  33. global skip_tracing
  34. if skip_tracing:
  35. yield
  36. return
  37. try:
  38. skip_tracing = True
  39. if active_trace is not None:
  40. active_trace._begin_excluded_region()
  41. yield
  42. finally:
  43. skip_tracing = False
  44. class TensorInfo:
  45. __slots__ = (
  46. # collected attributes
  47. "external",
  48. "exported",
  49. "data_read",
  50. "shape_read",
  51. "value_read",
  52. "device",
  53. "dtype",
  54. "shape",
  55. "bound_data",
  56. # resources for execution
  57. "varnode",
  58. "data_setter",
  59. "shape_reader",
  60. "value_reader",
  61. "data_reader",
  62. )
  63. def __init__(self):
  64. self.exported = None
  65. self.data_read = None
  66. self.shape_read = None
  67. self.value_read = None
  68. self.bound_data = None
  69. self.data_setter = None
  70. self.shape_reader = None
  71. self.value_reader = None
  72. self.data_reader = None
  73. class trace:
  74. """
  75. Wraps a callable and provides:
  76. * tracing via :meth:`.trace` and :meth:`.dump`
  77. * accelerated evaluation via :meth:`.__call__`
  78. :param function: the function will be traced.
  79. :param symbolic: whether to apply symbolic execution for tracing. Default: False
  80. :param capture_as_const: capture global vars or closures as const value. Default: False
  81. :param sublinear_memory_config: configuration for sublinear memory optimization.
  82. If not None, it enables sublinear memory optimization with given setting.
  83. :param profiling: whether to profile compiled trace. Default: False
  84. :param opt_level: optimization level for compiling trace.
  85. :param tensor_shape: whether to use symbolic shape for tracing. Default: True
  86. """
  87. def __new__(cls, *args, **kwargs):
  88. if not args:
  89. return functools.partial(cls, **kwargs)
  90. return super().__new__(cls)
  91. def __init__(
  92. self,
  93. function,
  94. symbolic=False,
  95. capture_as_const=False,
  96. sublinear_memory_config: SublinearMemoryConfig = None,
  97. profiling: bool = False,
  98. opt_level: int = None,
  99. tensor_shape: bool = True,
  100. ):
  101. self.__wrapped__ = function
  102. self._symbolic = symbolic
  103. self._capture_as_const = capture_as_const
  104. self._sublinear_memory_config = sublinear_memory_config
  105. self._profiling = profiling
  106. self._profiler = None
  107. self._graph_opt_level = opt_level
  108. self._tensor_shape = tensor_shape
  109. self._reset()
  110. def _reset(self):
  111. self._untraced = True
  112. self._tinfo = [] # handle -> TensorInfo
  113. self._seq = []
  114. self._pc = 0
  115. self._graph = None
  116. self._need_reset_nodes = None
  117. self._lazy_eval_graph = None
  118. self._lazy_eval_tensors = []
  119. self._lazy_eval_tensor_count = 0
  120. self._active_tensors = weakref.WeakSet()
  121. self._tensor_remaps = None
  122. self._inputs_to_restore = None
  123. self._arg_bindings = None
  124. self._kwarg_bindings = None
  125. self._output_bindings = None
  126. self._output_names = None
  127. set_tensor_shape(self._tensor_shape)
  128. def _new_handle(self):
  129. handle = len(self._tinfo)
  130. info = TensorInfo()
  131. self._tinfo.append(info)
  132. return handle, info
  133. def _apply_op(self, op, args):
  134. assert not self._untraced
  135. # check against trace
  136. if self._pc >= len(self._seq):
  137. raise TraceMismatchError("trace should end here, but more op observed")
  138. record = self._seq[self._pc]
  139. op_, ihandles, ohandles = record
  140. if op != op_:
  141. # FIXME: will be removed once better rng implementation is done
  142. if isinstance(op, OprAttr) and (
  143. op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type
  144. ):
  145. if op.param[8:] != op_.param[8:]:
  146. raise TraceMismatchError("op different from last time")
  147. else:
  148. raise TraceMismatchError("op different from last time")
  149. if len(ihandles) != len(args):
  150. raise TraceMismatchError("op input size different from last time")
  151. for h, x in zip(ihandles, args):
  152. info = self._tinfo[h]
  153. if info.external:
  154. if (
  155. x.__class__ is CompiledTensorProxy
  156. and not self._tinfo[x._CompiledTensorProxy__handle].exported
  157. ):
  158. raise TraceMismatchError(
  159. "failed to capture: input was an external tensor "
  160. "last time, got an internal tensor this time"
  161. )
  162. if info.bound_data:
  163. if x.__class__ is CompiledTensorProxy:
  164. raise TraceMismatchError(
  165. "const capture violated: was an external tensor "
  166. "last time, got an internal tensor this time"
  167. )
  168. if x._handle != info.bound_data._handle:
  169. if not np.array_equal(x.numpy(), info.bound_data.numpy()):
  170. raise TraceMismatchError(
  171. "const capture violated: got "
  172. "a different tensor this time"
  173. )
  174. else:
  175. if info.dtype != x.dtype:
  176. raise TraceMismatchError(
  177. "failed to capture: different dtype from last time"
  178. )
  179. if info.device != x.device:
  180. raise TraceMismatchError(
  181. "failed to capture: different device from last time"
  182. )
  183. info.data_setter.set_value(x._dev_tensor())
  184. else:
  185. if x.__class__ is not CompiledTensorProxy:
  186. if x not in self._tensor_remaps:
  187. raise TraceMismatchError(
  188. "unexpected capture: trying to use an external tensor as "
  189. "input, but that input was an internal tensor last time"
  190. )
  191. else:
  192. x = self._tensor_remaps[x]
  193. if x._CompiledTensorProxy__handle != h:
  194. raise TraceMismatchError(
  195. "mis-wiring: input edge to an data flow "
  196. "graph node is different from last time"
  197. )
  198. self._pc += 1
  199. outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
  200. self._active_tensors.update(outputs)
  201. return outputs
  202. def _record_op(self, op, inputs, outputs):
  203. if skip_tracing:
  204. for x in inputs:
  205. h = getattr(x, "_TraceMixin__handle", None)
  206. if h is not None:
  207. self._tinfo[h].data_read = True
  208. return
  209. ihandles = []
  210. for x in inputs:
  211. h = getattr(x, "_TraceMixin__handle", None)
  212. if h is None or (not self._capture_as_const and self._tinfo[h].exported):
  213. h, info = self._new_handle()
  214. info.external = True
  215. info.device = x.device
  216. info.dtype = x.dtype
  217. info.shape = x.shape
  218. if self._capture_as_const:
  219. info.bound_data = x
  220. ihandles.append(h)
  221. ohandles = []
  222. for x in outputs:
  223. h, info = self._new_handle()
  224. ohandles.append(h)
  225. info.external = False
  226. TraceMixin._TraceMixin__inject(x, h)
  227. self._seq.append((op, tuple(ihandles), tuple(ohandles)))
  228. self._active_tensors.update(outputs)
  229. def _record_const(self, op, outputs):
  230. pass
  231. def _set_active(self, active: bool):
  232. global active_trace
  233. if active:
  234. if active_trace:
  235. raise NotImplementedError("sorry, not implemented: nested trace")
  236. active_trace = self
  237. else:
  238. assert active_trace is self
  239. active_trace = None
  240. def _init_trace(self, symbolic: bool):
  241. apply.enable(apply_with_tracing)
  242. apply.enable(apply_const_with_tracing)
  243. if symbolic:
  244. apply.enable(apply_symbolic_mode)
  245. apply.enable(apply_const_symbolic_mode)
  246. self._lazy_eval_graph = G.Graph()
  247. def _take_escaped_tensors(self):
  248. escaped_tensors = tuple(self._active_tensors)
  249. self._active_tensors.clear()
  250. return escaped_tensors
  251. def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
  252. active_lazy_eval_tensors = []
  253. visited = set()
  254. readers = []
  255. for x in lazy_eval_tensors:
  256. x = x()
  257. if x is None or x in visited:
  258. continue
  259. reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
  260. readers.append(reader)
  261. active_lazy_eval_tensors.append(x)
  262. visited.add(x)
  263. self._apply_graph_options(lazy_eval_graph)
  264. lazy_eval_graph.compile(*readers)
  265. lazy_eval_graph()
  266. for r, x in zip(readers, active_lazy_eval_tensors):
  267. assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
  268. @contextlib.contextmanager
  269. def _setup(self):
  270. interrupted = False
  271. def do_enter():
  272. self._set_active(True)
  273. if self._untraced:
  274. self._init_trace(self._symbolic)
  275. else:
  276. apply.enable(apply_compiled_mode)
  277. if self._graph is None:
  278. self._compile()
  279. self._graph.execute()
  280. def do_finalize():
  281. escaped_tensors = self._take_escaped_tensors()
  282. if self._untraced:
  283. for x in escaped_tensors:
  284. info = self._tinfo[x._TraceMixin__handle]
  285. info.data_read = True
  286. x._TraceMixin__restore()
  287. if self._inputs_to_restore:
  288. for x in self._inputs_to_restore:
  289. x._TraceMixin__restore()
  290. if self._symbolic and self._lazy_eval_tensors:
  291. # eval lazy eval tensors
  292. self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
  293. self._lazy_eval_graph = None
  294. self._lazy_eval_tensors = None
  295. self._untraced = False
  296. else:
  297. # compiled_tensor leaks
  298. if self._pc == len(self._seq):
  299. for x in escaped_tensors:
  300. try:
  301. assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
  302. except TraceMismatchError:
  303. # TraceMismatchError thrown in do_exit
  304. pass
  305. self._graph.wait()
  306. self._reset_exec_env()
  307. # reset status
  308. self._pc = 0
  309. self._tensor_remaps = None
  310. apply.disable(apply_with_tracing)
  311. apply.disable(apply_const_with_tracing)
  312. apply.disable(apply_symbolic_mode)
  313. apply.disable(apply_const_symbolic_mode)
  314. apply.disable(apply_compiled_mode)
  315. self._set_active(False)
  316. def do_exit():
  317. if not self._untraced and self._pc != len(self._seq):
  318. raise TraceMismatchError("premature end")
  319. if not self._symbolic or not self._untraced:
  320. for x in self._active_tensors:
  321. x._dev_tensor()
  322. try:
  323. do_enter()
  324. yield
  325. do_exit()
  326. except:
  327. interrupted = True
  328. raise
  329. finally:
  330. do_finalize()
  331. if interrupted:
  332. self._reset()
  333. def _begin_excluded_region(self):
  334. if self._capture_as_const:
  335. raise RuntimeError(
  336. "exclude_from_trace cannot be used with capture_as_const"
  337. )
  338. if self._untraced:
  339. # conditionally reading a compiled tensor in excluded region
  340. # is permitted, so we have to assume every tensor might be read
  341. for x in self._active_tensors:
  342. info = self._tinfo[x._TraceMixin__handle]
  343. info.exported = True
  344. info.data_read = True
  345. def _apply_graph_options(self, graph):
  346. graph.options.seq_opt.enable_seq_comp_node_opt = False
  347. # graph opt level
  348. if self._graph_opt_level is not None:
  349. graph.options.graph_opt_level = self._graph_opt_level
  350. # sublinear
  351. if self._sublinear_memory_config is not None:
  352. graph.options.enable_sublinear_memory_opt = True
  353. sublinear_config = graph.options.sublinear_mem_config
  354. sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
  355. sublinear_config.genetic_nr_iter = (
  356. self._sublinear_memory_config.genetic_nr_iter
  357. )
  358. sublinear_config.genetic_pool_size = (
  359. self._sublinear_memory_config.genetic_pool_size
  360. )
  361. sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
  362. sublinear_config.num_worker = self._sublinear_memory_config.num_worker
  363. # profile
  364. if self._profiling:
  365. self._profiler = GraphProfiler(graph)
  366. def _compile(self):
  367. graph = self._graph = G.Graph()
  368. graph.options.no_force_inplace = True
  369. graph.options.async_exec_level = 0b100
  370. self._apply_graph_options(graph)
  371. # graph.options.graph_opt_level = 0
  372. need_reset_nodes = self._need_reset_nodes = []
  373. # links enforce ordering of I/O nodes
  374. links = ()
  375. readers = []
  376. if self._capture_as_const:
  377. for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
  378. info = self._tinfo[h]
  379. opnode = info.data_setter = G.InputNode(
  380. device=info.device, dtype=info.dtype, shape=info.shape, graph=graph
  381. )
  382. need_reset_nodes.append(opnode)
  383. info.varnode = opnode.outputs[0]
  384. links += opnode.outputs[1:]
  385. for op, ihandles, ohandles in self._seq:
  386. ivars = []
  387. for h in ihandles:
  388. info = self._tinfo[h]
  389. if not hasattr(info, "varnode"):
  390. assert info.external
  391. if info.bound_data:
  392. info.varnode = graph.make_const(info.bound_data._dev_tensor())
  393. else:
  394. opnode = info.data_setter = G.InputNode(
  395. *links,
  396. device=info.device,
  397. dtype=info.dtype,
  398. shape=info.shape,
  399. graph=graph,
  400. )
  401. need_reset_nodes.append(opnode)
  402. info.varnode, *links = opnode.outputs
  403. ivars.append(info.varnode)
  404. ovars = apply(op, *ivars)
  405. assert len(ovars) == len(ohandles)
  406. for h, v in zip(ohandles, ovars):
  407. info = self._tinfo[h]
  408. info.varnode = v
  409. def add_reader(opnode):
  410. nonlocal links
  411. need_reset_nodes.append(opnode)
  412. readers.append(opnode.outputs[0])
  413. links = opnode.outputs
  414. if info.data_read:
  415. # Shape can be obtained from data so doesn't need its own
  416. # output node. On the other hand, value is read separately
  417. # to leverage eager h2d copy
  418. info.shape_read = False
  419. opnode = info.data_reader = G.OutputNode(v, *links)
  420. add_reader(opnode)
  421. if info.value_read:
  422. opnode = info.value_reader = G.ValueOutputNode(v, *links)
  423. add_reader(opnode)
  424. if info.shape_read:
  425. opnode = info.shape_reader = G.AttrOutputNode(v, *links)
  426. add_reader(opnode)
  427. graph.compile(*readers)
  428. def _reset_exec_env(self):
  429. for opnode in self._need_reset_nodes:
  430. opnode.reset()
  431. def _require_shape(self, handle):
  432. info = self._tinfo[handle]
  433. info.shape_read = True
  434. def _require_value(self, handle):
  435. info = self._tinfo[handle]
  436. info.value_read = True
  437. def _require_data(self, handle):
  438. info = self._tinfo[handle]
  439. info.data_read = True
  440. def __call__(self, *args, **kwargs):
  441. with self._setup():
  442. if self._capture_as_const:
  443. self._process_inputs(*args, **kwargs)
  444. outputs = self.__wrapped__(*args, **kwargs)
  445. if self._capture_as_const:
  446. self._process_outputs(outputs)
  447. return outputs
  448. def dump(
  449. self,
  450. file,
  451. *,
  452. arg_names=None,
  453. output_names=None,
  454. append=False,
  455. optimize_for_inference=True,
  456. **kwargs
  457. ):
  458. r"""Serializes trace to file system.
  459. :param file: output file, could be file object or filename.
  460. :param arg_names: names of the input tensors in the traced function.
  461. :param output_names: names of the output tensors in the traced function,
  462. use the default name if not specified.
  463. :param append: whether output is appended to ``file``.
  464. Only works when ``file`` is str.
  465. :param optimize_for_inference: enbale optmizations,
  466. will skip all optimize options if this is False. Default: True
  467. :Keyword Arguments:
  468. * enable_io16xc32 --
  469. whether to use float16 for I/O between oprs and use
  470. float32 as internal computation precision. Note the output var would be
  471. changed to float16.
  472. * enable_ioc16 --
  473. whether to use float16 for both I/O and computation
  474. precision.
  475. * enable_hwcd4 --
  476. whether to use NHWCD4 data layout. This is faster on some
  477. OpenCL backend.
  478. * enable_nchw88 --
  479. whether to use NCHW88 data layout, currently
  480. used in X86 AVX backend.
  481. * enable_nchw44 --
  482. whether to use NCHW44 data layout, currently
  483. used in arm backend.
  484. * enable_nchw44_dot --
  485. whether to use NCHW44_dot data layout, currently
  486. used in armv8.2+dotprod backend.
  487. * enable_nchw4 --
  488. whether to use NCHW4 data layout, currently
  489. used in nvidia backend(based on cudnn).
  490. * enable_nchw32 --
  491. whether to use NCHW32 data layout, currently
  492. used in nvidia backend with tensorcore(based on cudnn).
  493. * enable_chwn4 --
  494. whether to use CHWN4 data layout, currently
  495. used in nvidia backend with tensorcore.
  496. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
  497. into one opr.
  498. * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
  499. input for inference on nvidia backend(this optimization pass will
  500. result in mismatch of the precision of output of training and
  501. inference)
  502. """
  503. if not self._capture_as_const:
  504. raise ValueError(
  505. "you must specify capture_as_const=True at __init__ to use dump"
  506. )
  507. if self._untraced:
  508. raise RuntimeError("should run at least once before calling dump")
  509. if self._output_names and output_names:
  510. raise TypeError(
  511. "cannot specify output_names when output is already in dict format"
  512. )
  513. if output_names and not isinstance(output_names, collections.abc.Sequence):
  514. output_names = (output_names,)
  515. if output_names and len(output_names) != len(self._output_bindings):
  516. raise ValueError(
  517. "wrong number of output_names, should be {} values".format(
  518. len(self._output_bindings)
  519. )
  520. )
  521. if arg_names and not isinstance(arg_names, collections.abc.Sequence):
  522. arg_names = (arg_names,)
  523. if arg_names and len(arg_names) != len(self._arg_bindings):
  524. raise ValueError(
  525. "wrong number of arg_names, should be {} values".format(
  526. len(self._arg_bindings)
  527. )
  528. )
  529. output_names = output_names or self._output_names
  530. h2v = {}
  531. graph = G.Graph()
  532. for i, h in enumerate(self._arg_bindings):
  533. info = self._tinfo[h]
  534. h2v[h] = graph.make_h2d(
  535. dtype=info.dtype,
  536. device=info.device,
  537. shape=info.shape,
  538. name=arg_names[i] if arg_names else None,
  539. )
  540. for k, h in self._kwarg_bindings.items():
  541. info = self._tinfo[h]
  542. h2v[h] = graph.make_h2d(
  543. dtype=info.dtype, device=info.device, shape=info.shape, name=k
  544. )
  545. for op, ihandles, ohandles in self._seq:
  546. ivars = []
  547. for h in ihandles:
  548. info = self._tinfo[h]
  549. if h not in h2v:
  550. assert info.external
  551. assert info.bound_data
  552. h2v[h] = graph.make_const(
  553. info.bound_data.numpy(), dtype=info.dtype, device=info.device
  554. )
  555. ivars.append(h2v[h])
  556. ovars = apply(op, *ivars)
  557. assert len(ovars) == len(ohandles)
  558. h2v.update(zip(ohandles, ovars))
  559. dest_vars = []
  560. for i, h in enumerate(self._output_bindings):
  561. v = h2v[h]
  562. if output_names:
  563. v.name = output_names[i]
  564. dest_vars.append(v)
  565. if optimize_for_inference:
  566. dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
  567. if isinstance(file, str):
  568. permission = "wb" if append == False else "ab"
  569. file = open(file, permission)
  570. dump_content, dump_info = G.dump_graph(dest_vars)
  571. file.write(dump_content)
  572. return dump_info
  573. def _process_inputs(self, *args, **kwargs):
  574. if self._untraced:
  575. self._inputs_to_restore = []
  576. def record_input(x):
  577. if x is None:
  578. return
  579. h, info = self._new_handle()
  580. info.external = False
  581. info.device = x.device
  582. info.dtype = x.dtype
  583. info.shape = x.shape
  584. TraceMixin._TraceMixin__inject(x, h)
  585. self._inputs_to_restore.append(x)
  586. return h
  587. self._arg_bindings = []
  588. for i, x in enumerate(args):
  589. x = find_raw_tensor(x)
  590. if x is None:
  591. raise TypeError(
  592. "positional arguments should all be tensor "
  593. "but args[%d] cannot be recognized as one" % i
  594. )
  595. self._arg_bindings.append(record_input(x))
  596. self._kwarg_bindings = {}
  597. for k, x in kwargs.items():
  598. x = find_raw_tensor(x)
  599. if x is not None:
  600. self._kwarg_bindings[k] = record_input(x)
  601. else:
  602. if len(args) != len(self._arg_bindings):
  603. raise TraceMismatchError("positional argument length mismatch")
  604. self._tensor_remaps = {}
  605. for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
  606. x = find_raw_tensor(x)
  607. if x is None:
  608. raise TypeError(
  609. "positional arguments should all be tensor "
  610. "but args[%d] cannot be recognized as one" % i
  611. )
  612. info = self._tinfo[h]
  613. if x.dtype != info.dtype:
  614. raise TypeError("args[%d].dtype different from last time" % i)
  615. if x.device != info.device:
  616. raise TypeError("args[%d].device different from last time" % i)
  617. info.data_setter.set_value(x._dev_tensor())
  618. self._tensor_remaps[x] = CompiledTensorProxy(h)
  619. kwargs_tensors = {}
  620. for k, x in kwargs.items():
  621. x = find_raw_tensor(x)
  622. if x is not None:
  623. kwargs_tensors[k] = x
  624. if set(kwargs_tensors) != set(self._kwarg_bindings):
  625. too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
  626. too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
  627. if too_many:
  628. raise TraceMismatchError(
  629. "keyword arguments found to be tensor this time "
  630. "but were non-tensor previously: %s" % " ".join(too_many)
  631. )
  632. if too_few:
  633. raise TraceMismatchError(
  634. "keyword arguments found to be non-tensor this time "
  635. "but were tensor previously: %s" % " ".join(too_few)
  636. )
  637. for k, h in self._kwarg_bindings.items():
  638. x = kwargs_tensors[k]
  639. info = self._tinfo[h]
  640. if x.dtype != info.dtype:
  641. raise TypeError("kwargs[%s].dtype different from last time" % k)
  642. if x.device != info.device:
  643. raise TypeError("kwargs[%s].device different from last time" % k)
  644. info.data_setter.set_value(x._dev_tensor())
  645. self._tensor_remaps[x] = CompiledTensorProxy(h)
  646. def _process_outputs(self, outputs):
  647. output_names = None
  648. if isinstance(outputs, collections.abc.Mapping):
  649. output_names, outputs = zip(*sorted(outputs.items()))
  650. elif not isinstance(outputs, collections.abc.Sequence):
  651. outputs = (outputs,)
  652. if not self._untraced:
  653. if output_names != self._output_names:
  654. too_many = set(output_names) - set(self._output_names)
  655. too_few = set(self._output_names) - set(output_names)
  656. if too_many:
  657. raise TraceMismatchError(
  658. "output has more keys than last time: %s" % " ".join(too_many)
  659. )
  660. if too_few:
  661. raise TraceMismatchError(
  662. "output has less keys than last time: %s" % " ".join(too_few)
  663. )
  664. if len(outputs) != len(self._output_bindings):
  665. raise TraceMismatchError("output size differs from last time")
  666. else:
  667. self._output_names = output_names
  668. self._output_bindings = []
  669. for i, x in enumerate(outputs):
  670. x = find_raw_tensor(x)
  671. if x is None:
  672. raise TypeError("every item of return value should be tensor")
  673. if self._untraced:
  674. if not isinstance(x, TraceMixin):
  675. raise RuntimeError("output is not computed from inputs")
  676. h = x._TraceMixin__handle
  677. self._output_bindings.append(h)
  678. else:
  679. if not isinstance(x, CompiledTensorProxy):
  680. raise RuntimeError("output is not computed from inputs")
  681. h = x._CompiledTensorProxy__handle
  682. if h != self._output_bindings[i]:
  683. raise TraceMismatchError(
  684. "retval[%s] is a different tensor than last time"
  685. % (output_names and output_names[i] or i)
  686. )
  687. def get_profile(self):
  688. """
  689. Get profiling result for compiled trace.
  690. :return: a json compatible object.
  691. """
  692. if not self._profiler:
  693. raise RuntimeError("trace is not set with profiling=True")
  694. return json.loads(self._profiler.get())
  695. def trace(self, *args, **kwargs):
  696. raise NotImplementedError(
  697. "trace is deemed unbeneficial with the new "
  698. "tracing mechanism. You should alwasy use __call__."
  699. )
  700. class CompiledTensorProxy(RawTensor):
  701. """
  702. Duck-typed RawTensor
  703. """
  704. def __init__(self, handle):
  705. self.__handle = handle
  706. self.__info = active_trace._tinfo[handle]
  707. self.__shape = None
  708. self.__data = None
  709. self.__value = None
  710. @property
  711. def dtype(self):
  712. return self.__info.varnode.dtype
  713. @property
  714. def device(self):
  715. return self.__info.varnode.device
  716. @property
  717. def shape(self):
  718. if self.__shape is None:
  719. if self.__info.shape_read:
  720. self.__shape = self.__info.shape_reader.get_value().shape
  721. elif self.__info.data_read:
  722. self.__shape = self._dev_tensor().shape
  723. else:
  724. raise TraceMismatchError("shape of this tensor is not read in trace")
  725. return self.__shape
  726. def numpy(self):
  727. if self.__value is None:
  728. if self.__info.value_read:
  729. self.__value = self.__info.value_reader.get_value()
  730. elif self.__info.data_read:
  731. self.__value = self._dev_tensor().numpy()
  732. else:
  733. raise TraceMismatchError("value of this tensor is not read in trace")
  734. return self.__value
  735. def _dev_tensor(self):
  736. if self.__data is None:
  737. if not self.__info.data_read:
  738. raise TraceMismatchError("raw data of this tensor is not read in trace")
  739. self.__data = self.__info.data_reader.get_value()
  740. return self.__data
  741. def __del__(self):
  742. if self.__info.shape_read and self.__shape is not None:
  743. self.__info.shape_reader.drop_value()
  744. if self.__info.value_read and self.__value is not None:
  745. self.__info.value_reader.drop_value()
  746. if self.__info.data_read and self.__data is not None:
  747. self.__info.data_reader.drop_value()
  748. class LazyEvalTensor(RawTensor):
  749. def __init__(self, varnode):
  750. self.__varnode = varnode
  751. @property
  752. def dtype(self):
  753. return self.__varnode.dtype
  754. @property
  755. def device(self):
  756. return self.__varnode.device
  757. @property
  758. def shape(self):
  759. return self.__varnode.shape
  760. def numpy(self):
  761. return self.__varnode.value
  762. def _dev_tensor(self):
  763. raise RuntimeError("cannot access data during symbolic tracing")
  764. class TraceMixin:
  765. __subclass_cache = {}
  766. def __inject(self, handle):
  767. cache = __class__.__subclass_cache
  768. cls = self.__class__
  769. subcls = cache.get(cls)
  770. if subcls is None:
  771. subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
  772. self.__class__ = subcls
  773. self.__handle = handle
  774. self.__cls = cls
  775. return self
  776. def __restore(self):
  777. cls = self.__cls
  778. del self.__handle
  779. del self.__cls
  780. self.__class__ = cls
  781. return self
  782. @property
  783. def shape(self):
  784. if not skip_tracing:
  785. active_trace._require_shape(self.__handle)
  786. return super().shape
  787. def numpy(self):
  788. if not skip_tracing:
  789. active_trace._require_value(self.__handle)
  790. return super().numpy()
  791. def _dev_tensor(self):
  792. if not skip_tracing:
  793. active_trace._require_data(self.__handle)
  794. return super()._dev_tensor()
  795. class TracedRawTensor(TraceMixin, RawTensor):
  796. pass
  797. class TracedLazyTensor(TraceMixin, LazyEvalTensor):
  798. pass
  799. def assign_raw_tensor(lhs, rhs):
  800. handle = rhs._handle
  801. rhs.__dict__.clear()
  802. lhs.__dict__.clear()
  803. lhs.__class__ = RawTensor
  804. lhs.__init__(handle)
  805. # this hook turns RawTensor into LazyEvalTensor
  806. @apply.register()
  807. def apply_symbolic_mode(op: OpDef, *args: RawTensor):
  808. graph = active_trace._lazy_eval_graph
  809. ivars = [
  810. getattr(x, "_LazyEvalTensor__varnode", None)
  811. or graph.make_const(x._dev_tensor())
  812. for x in args
  813. ]
  814. ovars = apply(op, *ivars)
  815. outputs = [LazyEvalTensor(v) for v in ovars]
  816. active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
  817. return outputs
  818. apply.disable(apply_symbolic_mode)
  819. @apply.register()
  820. def apply_const_symbolic_mode(op: Const, *args: RawTensor):
  821. graph = active_trace._lazy_eval_graph
  822. ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
  823. active_trace._lazy_eval_tensors.append(weakref.ref(ret))
  824. return (ret,)
  825. apply.disable(apply_const_symbolic_mode)
  826. @apply.register()
  827. def apply_compiled_mode(op: OpDef, *args: RawTensor):
  828. if skip_tracing:
  829. args = [
  830. as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
  831. for x in args
  832. ]
  833. return apply.super(op, *args)
  834. return active_trace._apply_op(op, args)
  835. apply.disable(apply_compiled_mode)
  836. # this hook injects TraceMixin
  837. @apply.register()
  838. def apply_with_tracing(op: OpDef, *args: RawTensor):
  839. outputs = apply.super(op, *args)
  840. active_trace._record_op(op, args, outputs)
  841. return outputs
  842. apply.disable(apply_with_tracing)
  843. @apply.register()
  844. def apply_const_with_tracing(op: Const, *args: RawTensor):
  845. outputs = apply.super(op, *args)
  846. active_trace._record_const(op, outputs)
  847. return outputs
  848. apply.disable(apply_const_with_tracing)
  849. class BrokenRawTensor(RawTensor):
  850. def __getattribute__(self, _):
  851. raise RuntimeError("broken due to misuse of tracing")
  852. def __setattr__(self, *_):
  853. raise RuntimeError("broken due to misuse of tracing")
  854. @functools.singledispatch
  855. def find_raw_tensor(x):
  856. return None
  857. @find_raw_tensor.register(RawTensor)
  858. def _(x):
  859. return x
  860. @find_raw_tensor.register(TensorWrapperBase)
  861. def _(x):
  862. x = getattr(x, "__wrapped__", None)
  863. if x is not None:
  864. return find_raw_tensor(x)
  865. @find_raw_tensor.register(Tensor)
  866. def _(x):
  867. x = getattr(x, "_data", None)
  868. if x is not None:
  869. return find_raw_tensor(x)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台