
tracing.py 17 kB

import contextlib
import functools
import typing
import weakref

from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor


class TraceMismatchError(RuntimeError):
    pass


active_trace = None
skip_tracing = False


@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
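
# Illustrative usage sketch (an added note, not part of the original source):
# inside a traced function, `exclude_from_trace` lets a block run eagerly on
# every call instead of being matched against the recorded op sequence, e.g.
#
#     @trace
#     def f(x):
#         y = x * 2
#         with exclude_from_trace():
#             print(y.numpy())  # executed eagerly on each call
#         return y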
class TensorInfo:
    __slots__ = (
        # collected attributes
        "external",
        "exported",
        "data_read",
        "shape_read",
        "value_read",
        "device",
        "dtype",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None
class trace:
    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        self = super().__new__(cls)
        self.__init__(*args, **kwargs)
        return self

    def __init__(self, function, symbolic=False, capture_as_const=False):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._capture_static_shape = False

        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = weakref.WeakSet()
        self._active_tensors = weakref.WeakSet()

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more ops observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if op != op_:
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        raise TraceMismatchError(
                            "const capture violated: got "
                            "a different tensor this time"
                        )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x.__class__ is not CompiledTensorProxy:
                    raise TraceMismatchError(
                        "unexpected capture: trying to use an external tensor as input, "
                        "but that input was an internal tensor last time"
                    )
                if x._CompiledTensorProxy__handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to a data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
        self._active_tensors.update(outputs)
        return outputs

    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_TraceMixin__handle", None)
                if h is not None:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_TraceMixin__handle", None)
            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                if self._capture_as_const:
                    info.bound_data = x

            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            TraceMixin._TraceMixin__inject(x, h)

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
        self._active_tensors.update(outputs)

    def _record_const(self, op, outputs):
        pass

    @contextlib.contextmanager
    def _setup(self):
        global active_trace
        if active_trace:
            raise NotImplementedError("sorry, not implemented: nested trace")
        active_trace = self

        if self._untraced:
            apply.enable(apply_with_tracing)
            apply.enable(apply_const_with_tracing)
            if self._symbolic:
                apply.enable(apply_symbolic_mode)
                apply.enable(apply_const_symbolic_mode)
                self._lazy_eval_graph = G.Graph()
        else:
            apply.enable(apply_compiled_mode)
            if self._graph is None:
                self._compile()
            self._graph.execute()

        yield

        escaped_tensors = tuple(self._active_tensors)
        self._active_tensors.clear()

        if self._untraced:
            for x in escaped_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.data_read = True
                x._TraceMixin__restore()
            if self._symbolic:
                # eval lazy eval tensors
                lazy_eval_tensors = tuple(self._lazy_eval_tensors)
                if lazy_eval_tensors:
                    readers = [
                        G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                        for x in lazy_eval_tensors
                    ]
                    self._lazy_eval_graph.compile(*readers)
                    self._lazy_eval_graph()
                    for r, x in zip(readers, lazy_eval_tensors):
                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
            self._untraced = False
        else:
            if self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            for x in escaped_tensors:
                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
            self._graph.wait()
            self._reset_exec_env()
            self._pc = 0

        apply.disable(apply_with_tracing)
        apply.disable(apply_const_with_tracing)
        apply.disable(apply_symbolic_mode)
        apply.disable(apply_const_symbolic_mode)
        apply.disable(apply_compiled_mode)
        active_trace = None

    def _begin_excluded_region(self):
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.exported = True
                info.data_read = True

    def _compile(self):
        graph = self._graph = G.Graph()
        graph.options.no_force_inplace = True
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        links = ()
        for op, ihandles, ohandles in self._seq:
            ivars = []
            readers = []
            for h in ihandles:
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *links, device=info.device, dtype=info.dtype, graph=graph
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *links = opnode.outputs

                ivars.append(info.varnode)
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
                    add_reader(opnode)
        graph.compile(*readers)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def _require_shape(self, handle):
        info = self._tinfo[handle]
        info.shape_read = True

    def _require_value(self, handle):
        info = self._tinfo[handle]
        info.value_read = True

    def _require_data(self, handle):
        info = self._tinfo[handle]
        info.data_read = True

    def __call__(self, *args, **kwargs):
        with self._setup():
            return self.__wrapped__(*args, **kwargs)
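
# Replay contract (an added illustrative note, not original source): the first
# call records (op, input handles, output handles) triples into self._seq;
# every later call must apply the same ops in the same order, otherwise
# _apply_op raises TraceMismatchError, e.g.
#
#     @trace
#     def f(x, flag):
#         return x + 1 if flag else x - 1
#
#     f(data, True)   # first call records the `+ 1` op sequence
#     f(data, False)  # raises TraceMismatchError("op different from last time")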
class CompiledTensorProxy(RawTensor):
    """
    Duck-typed RawTensor
    """

    def __init__(self, handle):
        self.__handle = handle
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                raise TraceMismatchError("shape of this tensor is not read in trace")
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor is not read in trace")
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                raise TraceMismatchError("raw data of this tensor is not read in trace")
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()
class LazyEvalTensor(RawTensor):
    def __init__(self, varnode):
        self.__varnode = varnode

    @property
    def dtype(self):
        return self.__varnode.dtype

    @property
    def device(self):
        return self.__varnode.device

    @property
    def shape(self):
        return self.__varnode.shape

    def numpy(self):
        return self.__varnode.value

    def _dev_tensor(self):
        raise RuntimeError("cannot access data during symbolic tracing")
class TraceMixin:
    __subclass_cache = {}

    def __inject(self, handle):
        cache = __class__.__subclass_cache
        cls = self.__class__
        subcls = cache.get(cls)
        if subcls is None:
            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
        self.__class__ = subcls
        self.__handle = handle
        self.__cls = cls
        return self

    def __restore(self):
        cls = self.__cls
        del self.__handle
        del self.__cls
        self.__class__ = cls
        return self

    @property
    def shape(self):
        if not skip_tracing:
            active_trace._require_shape(self.__handle)
        return super().shape

    def numpy(self):
        if not skip_tracing:
            active_trace._require_value(self.__handle)
        return super().numpy()

    def _dev_tensor(self):
        if not skip_tracing:
            active_trace._require_data(self.__handle)
        return super()._dev_tensor()


class TracedRawTensor(TraceMixin, RawTensor):
    pass


class TracedLazyTensor(TraceMixin, LazyEvalTensor):
    pass
def assign_raw_tensor(lhs, rhs):
    handle = rhs._handle
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)


# this hook turns RawTensor into LazyEvalTensor
@apply.register()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = [
        getattr(x, "_LazyEvalTensor__varnode", None)
        or graph.make_const(x._dev_tensor())
        for x in args
    ]
    ovars = apply(op, *ivars)
    outputs = [LazyEvalTensor(v) for v in ovars]
    active_trace._lazy_eval_tensors.update(outputs)
    return outputs


apply.disable(apply_symbolic_mode)


@apply.register()
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
    return (ret,)


apply.disable(apply_const_symbolic_mode)


@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        return apply.super(op, *args)
    return active_trace._apply_op(op, args)


apply.disable(apply_compiled_mode)


# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_op(op, args, outputs)
    return outputs


apply.disable(apply_with_tracing)


@apply.register()
def apply_const_with_tracing(op: Const, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_const(op, outputs)
    return outputs


apply.disable(apply_const_with_tracing)


class BrokenRawTensor(RawTensor):
    def __getattribute__(self, _):
        raise RuntimeError("broken due to misuse of tracing")

    def __setattr__(self, *_):
        raise RuntimeError("broken due to misuse of tracing")
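
A minimal usage sketch of the `trace` decorator (illustrative; it assumes the usual `megengine.jit.trace` entry point and a tensor constructor from the surrounding package, neither of which is defined in this file):

    from megengine import tensor
    from megengine.jit import trace

    @trace(symbolic=True)       # record through the lazy-eval graph
    def step(x):
        return x * 2 + 1

    step(tensor([1.0, 2.0]))    # first call: ops are recorded
    step(tensor([3.0, 4.0]))    # later calls: the compiled graph is replayed

On the first call the hooks `apply_with_tracing` / `apply_symbolic_mode` record each op; on subsequent calls `apply_compiled_mode` checks every op against the recorded sequence and raises `TraceMismatchError` on any divergence.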

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware installed along with its driver. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.