
tracing.py 16 kB

import contextlib
import functools
import typing
import weakref

from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor


class TraceMismatchError(RuntimeError):
    pass


active_trace = None
skip_tracing = False


@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
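

# A minimal usage sketch, assuming ``y`` is an intermediate tensor computed
# under an active trace: the block below runs eagerly and is excluded from
# the recorded op sequence, so e.g. debug prints stay out of the graph.
#
#     with exclude_from_trace():
#         print(y.numpy())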


class TensorInfo:
    __slots__ = (
        # collected attributes
        "external",
        "exported",
        "data_read",
        "shape_read",
        "value_read",
        "device",
        "dtype",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None


class trace:
    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        self = super().__new__(cls)
        self.__init__(*args, **kwargs)
        return self

    def __init__(self, function, symbolic=False, capture_as_const=False):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._capture_static_shape = False

        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = weakref.WeakSet()
        self._active_tensors = weakref.WeakSet()

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if op != op_:
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        raise TraceMismatchError(
                            "const capture violated: got "
                            "a different tensor this time"
                        )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x.__class__ is not CompiledTensorProxy:
                    raise TraceMismatchError(
                        "unexpected capture: trying to use an external tensor as input, "
                        "but that input was an internal tensor last time"
                    )
                if x._CompiledTensorProxy__handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to an data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
        self._active_tensors.update(outputs)
        return outputs

    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_TraceMixin__handle", None)
                if h is not None:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_TraceMixin__handle", None)
            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                if self._capture_as_const:
                    info.bound_data = x

            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            TraceMixin._TraceMixin__inject(x, h)

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
        self._active_tensors.update(outputs)

    @contextlib.contextmanager
    def _setup(self):
        global active_trace
        if active_trace:
            raise NotImplementedError("sorry, not implemented: nested trace")
        active_trace = self

        if self._untraced:
            apply.enable(apply_with_tracing)
            if self._symbolic:
                apply.enable(apply_symbolic_mode)
                self._lazy_eval_graph = G.Graph()
        else:
            apply.enable(apply_compiled_mode)
            if self._graph is None:
                self._compile()
            self._graph.execute()

        yield

        escaped_tensors = tuple(self._active_tensors)
        self._active_tensors.clear()

        if self._untraced:
            for x in escaped_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.data_read = True
                x._TraceMixin__restore()
            if self._symbolic:
                # eval lazy eval tensors
                lazy_eval_tensors = tuple(self._lazy_eval_tensors)
                if lazy_eval_tensors:
                    readers = [
                        G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                        for x in lazy_eval_tensors
                    ]
                    self._lazy_eval_graph.compile(*readers)
                    self._lazy_eval_graph()
                    for r, x in zip(readers, lazy_eval_tensors):
                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
            self._untraced = False
        else:
            if self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            for x in escaped_tensors:
                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
            self._graph.wait()
            self._reset_exec_env()
            self._pc = 0

        apply.disable(apply_with_tracing)
        apply.disable(apply_symbolic_mode)
        apply.disable(apply_compiled_mode)
        active_trace = None

    def _begin_excluded_region(self):
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.exported = True
                info.data_read = True

    def _compile(self):
        graph = self._graph = G.Graph()
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        links = ()
        for op, ihandles, ohandles in self._seq:
            ivars = []
            readers = []
            for h in ihandles:
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *links, device=info.device, dtype=info.dtype, graph=graph
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *links = opnode.outputs
                ivars.append(info.varnode)
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
                    add_reader(opnode)
        graph.compile(*readers)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def _require_shape(self, handle):
        info = self._tinfo[handle]
        info.shape_read = True

    def _require_value(self, handle):
        info = self._tinfo[handle]
        info.value_read = True

    def _require_data(self, handle):
        info = self._tinfo[handle]
        info.data_read = True

    def __call__(self, *args, **kwargs):
        with self._setup():
            return self.__wrapped__(*args, **kwargs)
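

# A minimal usage sketch for the ``trace`` decorator above, assuming ``x`` is
# a tensor accepted by the wrapped function: the first call records the op
# sequence (building a lazy-eval graph when symbolic=True); later calls
# replay it through the compiled graph and raise TraceMismatchError on any
# divergence.
#
#     @trace(symbolic=True)
#     def f(x):
#         return x * 2
#
#     y = f(x)  # first call: trace and record
#     y = f(x)  # subsequent calls: execute the compiled graph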


class CompiledTensorProxy(RawTensor):
    """
    Duck-typed RawTensor
    """

    def __init__(self, handle):
        self.__handle = handle
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                raise TraceMismatchError("shape of this tensor is not read in trace")
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor is not read in trace")
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                raise TraceMismatchError("raw data of this tensor is not read in trace")
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()


class LazyEvalTensor(RawTensor):
    def __init__(self, varnode):
        self.__varnode = varnode

    @property
    def dtype(self):
        return self.__varnode.dtype

    @property
    def device(self):
        return self.__varnode.device

    @property
    def shape(self):
        return self.__varnode.shape

    def numpy(self):
        raise RuntimeError("cannot read value during symbolic tracing")

    def _dev_tensor(self):
        raise RuntimeError("cannot access data during symbolic tracing")


class TraceMixin:
    __subclass_cache = {}

    def __inject(self, handle):
        cache = __class__.__subclass_cache
        cls = self.__class__
        subcls = cache.get(cls)
        if subcls is None:
            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
        self.__class__ = subcls
        self.__handle = handle
        self.__cls = cls
        return self

    def __restore(self):
        cls = self.__cls
        del self.__handle
        del self.__cls
        self.__class__ = cls
        return self

    @property
    def shape(self):
        if not skip_tracing:
            active_trace._require_shape(self.__handle)
        return super().shape

    def numpy(self):
        if not skip_tracing:
            active_trace._require_value(self.__handle)
        return super().numpy()

    def _dev_tensor(self):
        if not skip_tracing:
            active_trace._require_data(self.__handle)
        return super()._dev_tensor()
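

# Note on the mixin above: __inject swaps an instance's class to a cached
# dynamic subclass ("Traced" + cls.__name__) so that shape/numpy/_dev_tensor
# accesses are reported to the active trace; __restore swaps the original
# class back when the traced run ends.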


class TracedRawTensor(TraceMixin, RawTensor):
    pass


class TracedLazyTensor(TraceMixin, LazyEvalTensor):
    pass


def assign_raw_tensor(lhs, rhs):
    handle = rhs._handle
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)


# this hook turns RawTensor into LazyEvalTensor
@apply.register()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = [
        getattr(x, "_LazyEvalTensor__varnode", None)
        or graph.make_const(x._dev_tensor())
        for x in args
    ]
    ovars = apply(op, *ivars)
    outputs = [LazyEvalTensor(v) for v in ovars]
    active_trace._lazy_eval_tensors.update(outputs)
    return outputs


apply.disable(apply_symbolic_mode)


@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        return apply.super(op, *args)
    return active_trace._apply_op(op, args)


apply.disable(apply_compiled_mode)


# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_op(op, args, outputs)
    return outputs


apply.disable(apply_with_tracing)


# @apply.register()
# def _(op: Const, *args: RawTensor):
#     return active_trace._apply_const(op, args)


class BrokenRawTensor(RawTensor):
    def __getattribute__(self, _):
        raise RuntimeError("broken due to misuse of tracing")

    def __setattr__(self, *_):
        raise RuntimeError("broken due to misuse of tracing")

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. If you want to run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
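
A quick way to confirm that setup is to query CUDA availability before running GPU code; the sketch below assumes megengine.is_cuda_available() is present in the installed version.

import megengine

# Minimal sketch: check whether the bundled CUDA runtime can see a usable GPU.
# is_cuda_available() is assumed to exist in your MegEngine version.
if megengine.is_cuda_available():
    print("GPU detected: CUDA programs can run on this machine")
else:
    print("no usable GPU found: code will run on CPU")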