
tracing.py 18 kB

import contextlib
import functools
import typing
import weakref

from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
from .sublinear_memory_config import SublinearMemoryConfig

class TraceMismatchError(RuntimeError):
    pass


active_trace = None
skip_tracing = False


@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False

class TensorInfo:
    __slots__ = (
        # collected attributes
        "external",
        "exported",
        "data_read",
        "shape_read",
        "value_read",
        "device",
        "dtype",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None

class trace:
    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        self = super().__new__(cls)
        self.__init__(*args, **kwargs)
        return self

    def __init__(
        self,
        function,
        symbolic=False,
        capture_as_const=False,
        sublinear_memory_config: SublinearMemoryConfig = None,
    ):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._capture_static_shape = False
        self._sublinear_memory_config = sublinear_memory_config

        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = weakref.WeakSet()
        self._active_tensors = weakref.WeakSet()

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if op != op_:
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        raise TraceMismatchError(
                            "const capture violated: got "
                            "a different tensor this time"
                        )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x.__class__ is not CompiledTensorProxy:
                    raise TraceMismatchError(
                        "unexpected capture: trying to use an external tensor as input, "
                        "but that input was an internal tensor last time"
                    )
                if x._CompiledTensorProxy__handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to a data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
        self._active_tensors.update(outputs)
        return outputs

    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_TraceMixin__handle", None)
                if h is not None:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_TraceMixin__handle", None)
            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                if self._capture_as_const:
                    info.bound_data = x
            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            TraceMixin._TraceMixin__inject(x, h)

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
        self._active_tensors.update(outputs)

    def _record_const(self, op, outputs):
        pass

    @contextlib.contextmanager
    def _setup(self):
        global active_trace
        if active_trace:
            raise NotImplementedError("sorry, not implemented: nested trace")
        active_trace = self

        if self._untraced:
            apply.enable(apply_with_tracing)
            apply.enable(apply_const_with_tracing)
            if self._symbolic:
                apply.enable(apply_symbolic_mode)
                apply.enable(apply_const_symbolic_mode)
                self._lazy_eval_graph = G.Graph()
        else:
            apply.enable(apply_compiled_mode)
            if self._graph is None:
                self._compile()
            self._graph.execute()

        yield

        escaped_tensors = tuple(self._active_tensors)
        self._active_tensors.clear()

        if self._untraced:
            for x in escaped_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.data_read = True
                x._TraceMixin__restore()
            if self._symbolic:
                # eval lazy eval tensors
                lazy_eval_tensors = tuple(self._lazy_eval_tensors)
                if lazy_eval_tensors:
                    readers = [
                        G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                        for x in lazy_eval_tensors
                    ]
                    self._apply_graph_options(self._lazy_eval_graph)
                    self._lazy_eval_graph.compile(*readers)
                    self._lazy_eval_graph()
                    for r, x in zip(readers, lazy_eval_tensors):
                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
            self._untraced = False
        else:
            if self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            for x in escaped_tensors:
                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
            self._graph.wait()
            self._reset_exec_env()
            self._pc = 0

        apply.disable(apply_with_tracing)
        apply.disable(apply_const_with_tracing)
        apply.disable(apply_symbolic_mode)
        apply.disable(apply_const_symbolic_mode)
        apply.disable(apply_compiled_mode)
        active_trace = None

    def _begin_excluded_region(self):
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.exported = True
                info.data_read = True

    def _apply_graph_options(self, graph):
        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
            sublinear_config = graph.options.sublinear_mem_config
            sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
            sublinear_config.genetic_nr_iter = (
                self._sublinear_memory_config.genetic_nr_iter
            )
            sublinear_config.genetic_pool_size = (
                self._sublinear_memory_config.genetic_pool_size
            )
            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
            sublinear_config.num_worker = self._sublinear_memory_config.num_worker

    def _compile(self):
        graph = self._graph = G.Graph()
        graph.options.no_force_inplace = True
        self._apply_graph_options(graph)
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        links = ()
        # readers accumulates output nodes across the whole op sequence so
        # graph.compile() below sees every read, not just the last op's
        readers = []

        for op, ihandles, ohandles in self._seq:
            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *links, device=info.device, dtype=info.dtype, graph=graph
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *links = opnode.outputs
                ivars.append(info.varnode)
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
                    add_reader(opnode)
        graph.compile(*readers)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def _require_shape(self, handle):
        info = self._tinfo[handle]
        info.shape_read = True

    def _require_value(self, handle):
        info = self._tinfo[handle]
        info.value_read = True

    def _require_data(self, handle):
        info = self._tinfo[handle]
        info.data_read = True

    def __call__(self, *args, **kwargs):
        with self._setup():
            return self.__wrapped__(*args, **kwargs)

class CompiledTensorProxy(RawTensor):
    """
    Duck-typed RawTensor
    """

    def __init__(self, handle):
        self.__handle = handle
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                raise TraceMismatchError("shape of this tensor is not read in trace")
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor is not read in trace")
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                raise TraceMismatchError("raw data of this tensor is not read in trace")
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()

class LazyEvalTensor(RawTensor):
    def __init__(self, varnode):
        self.__varnode = varnode

    @property
    def dtype(self):
        return self.__varnode.dtype

    @property
    def device(self):
        return self.__varnode.device

    @property
    def shape(self):
        return self.__varnode.shape

    def numpy(self):
        return self.__varnode.value

    def _dev_tensor(self):
        raise RuntimeError("cannot access data during symbolic tracing")

class TraceMixin:
    __subclass_cache = {}

    def __inject(self, handle):
        cache = __class__.__subclass_cache
        cls = self.__class__
        subcls = cache.get(cls)
        if subcls is None:
            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
        self.__class__ = subcls
        self.__handle = handle
        self.__cls = cls
        return self

    def __restore(self):
        cls = self.__cls
        del self.__handle
        del self.__cls
        self.__class__ = cls
        return self

    @property
    def shape(self):
        if not skip_tracing:
            active_trace._require_shape(self.__handle)
        return super().shape

    def numpy(self):
        if not skip_tracing:
            active_trace._require_value(self.__handle)
        return super().numpy()

    def _dev_tensor(self):
        if not skip_tracing:
            active_trace._require_data(self.__handle)
        return super()._dev_tensor()

class TracedRawTensor(TraceMixin, RawTensor):
    pass


class TracedLazyTensor(TraceMixin, LazyEvalTensor):
    pass


def assign_raw_tensor(lhs, rhs):
    handle = rhs._handle
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)

# this hook turns RawTensor into LazyEvalTensor
@apply.register()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = [
        getattr(x, "_LazyEvalTensor__varnode", None)
        or graph.make_const(x._dev_tensor())
        for x in args
    ]
    ovars = apply(op, *ivars)
    outputs = [LazyEvalTensor(v) for v in ovars]
    active_trace._lazy_eval_tensors.update(outputs)
    return outputs


apply.disable(apply_symbolic_mode)


@apply.register()
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
    return (ret,)


apply.disable(apply_const_symbolic_mode)


@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        return apply.super(op, *args)
    return active_trace._apply_op(op, args)


apply.disable(apply_compiled_mode)


# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_op(op, args, outputs)
    return outputs


apply.disable(apply_with_tracing)


@apply.register()
def apply_const_with_tracing(op: Const, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_const(op, outputs)
    return outputs


apply.disable(apply_const_with_tracing)


class BrokenRawTensor(RawTensor):
    def __getattribute__(self, _):
        raise RuntimeError("broken due to misuse of tracing")

    def __setattr__(self, *_):
        raise RuntimeError("broken due to misuse of tracing")
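
For reference, a minimal usage sketch of the trace decorator defined above. The megengine import path and the mge.tensor constructor are assumptions based on MegEngine's public API, not anything shown in this file:

import numpy as np
import megengine as mge
from megengine.jit import trace

@trace(symbolic=True)  # symbolic=True records into a lazy-eval graph on the first call
def add_one(x):
    return x + 1

x = mge.tensor(np.arange(4, dtype="float32"))
print(add_one(x).numpy())  # first call traces the op sequence
print(add_one(x).numpy())  # later calls replay the compiled graph

Calling the traced function a second time takes the compiled path; any divergence from the recorded op sequence raises TraceMismatchError.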

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and the driver installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
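
For example, a quick runtime check before launching GPU work might look like this; the is_cuda_available helper is assumed from MegEngine's public API:

import megengine as mge

# Assumed helper: reports whether the bundled CUDA runtime can see a usable GPU.
if mge.is_cuda_available():
    print("GPU detected, running on CUDA")
else:
    print("no GPU detected, falling back to CPU")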