
__init__.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
import functools
import itertools
import os
from typing import Callable, Tuple, Union

import numpy as np

import megengine._internal as mgb
from megengine._internal.plugin import CompGraphProfiler

from ..core import Tensor, graph, tensor


def sideeffect(f):
    # during eager tracing, wrapped function is called with proxy inputs
    # during static tracing, wrapped function will not be called at all
    @functools.wraps(f)
    def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
        if not trace._active_instance:
            return f(*args, **kwargs)

        tensors = {}
        for i, x in itertools.chain(enumerate(args), kwargs.items()):
            if isinstance(x, Tensor):
                tensors[i] = x
        if tensors:
            _keys, tensors = zip(*tensors.items())
        else:
            _keys, tensors = (), ()

        def callback(*tensors, f=f, keys=_keys, args=args, kwargs=kwargs):
            replace = dict(zip(keys, tensors))
            args = tuple(replace.get(i, x) for i, x in enumerate(args))
            kwargs = {i: replace.get(i, x) for i, x in kwargs.items()}
            if f(*args, **kwargs) is not None:
                raise TypeError("a sideeffect function should return None")
            # TODO: clear memory

        trace._active_instance._register_callback(callback, tensors)

    return wrapper


def mark_impure(x):
    if not trace._active_instance:
        return x
    return trace._active_instance._mark_impure(x)


def barrier(x):
    if not trace._active_instance:
        return x
    return trace._active_instance._insert_barrier(x)


def _dummy():
    return mgb.make_immutable(*graph._use_default_if_none(None, None), 0)


class unset:
    pass


class trace:
    """
    Wrap a callable and provide:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    :param func: Positional only argument.
    :param symbolic: Whether to use symbolic tensor.
    :param opt_level: Optimization level for compiling trace.
    :param log_level: Log level.
    :param profiling: Whether to profile compiled trace.
    """

    _active_instance = None
    enabled = not os.getenv("MGE_DISABLE_TRACE")

    _UNSTARTED = "unstarted"
    _STARTED = "started"
    _FINISHED = "finished"

    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        func: Callable[..., Union[None, Tensor, Tuple[Tensor]]],
        *,
        symbolic: bool = False,
        opt_level: int = None,
        log_level: int = None,
        profiling: bool = False
    ):
        self.__wrapped__ = func
        self._symbolic = symbolic
        self._graph_opt_level = opt_level
        self._log_level = log_level
        self._status = self._UNSTARTED
        self._args = None
        self._kwargs = None
        self._outputs = unset
        self._sym_outputs = unset
        self._outspec = None
        self._checkpoint = None
        self._compiled_func = None
        self._profiling = profiling
        self._profiler = None

    @property
    def _active(self):
        c1 = self._status == self._STARTED
        c2 = type(self)._active_instance is self
        assert c1 == c2
        return c1

    def _register_callback(self, f, args=()):
        assert self._active
        assert isinstance(args, (tuple, list))
        proxies = self._make_proxies(args)
        self._forward(args, proxies, checkpoint=True)
        # NOTE: under eager graph the callback will fire immediately
        job = mgb.opr.callback_injector(
            self._insert_barrier(_dummy()), lambda _: f(*proxies)
        )
        self._insert_checkpoint(job)
        self._outspec.append(job)

    def _insert_barrier(self, x):
        assert self._active
        if self._checkpoint is None:
            return x
        if isinstance(x, Tensor):
            x = x._symvar
            wrap = True
        else:
            wrap = False
        if not isinstance(x, mgb.SymbolVar):
            raise TypeError
        x = mgb.opr.virtual_dep([x, self._checkpoint])
        if wrap:
            x = Tensor(x)
        return x

    def _insert_checkpoint(self, *args, no_barrier=False):
        assert self._active
        if not args:
            return
        args = tuple(x._symvar if isinstance(x, Tensor) else x for x in args)
        for x in args:
            if not isinstance(x, mgb.SymbolVar):
                raise TypeError
        if not no_barrier and self._checkpoint is not None:
            # normally no need to _insert_barrier here, but if
            # someone forgets to call _insert_barrier beforehand,
            # this can make things less broken
            args += (self._checkpoint,)
        if len(args) == 1:
            self._checkpoint = args[0]
        else:
            self._checkpoint = mgb.opr.virtual_dep(args)

    def _mark_impure(self, x):
        assert self._active
        ret = x
        if isinstance(x, Tensor):
            x = x._symvar
        if not isinstance(x, mgb.SymbolVar):
            raise TypeError
        self._outspec.append(x)
        self._insert_checkpoint(x)
        return ret

    def _make_proxies(self, args):
        assert isinstance(args, (tuple, list))
        for x in args:
            assert isinstance(x, Tensor)
        return tuple(tensor(dtype=x.dtype, device=x.device) for x in args)

    def _forward(self, srcs, dests, checkpoint=True):
        # pseudo-op: does not run under static graph; traced
        # TODO: use shared memory
        assert len(srcs) == len(dests)
        if not self._active:
            for s, d in zip(srcs, dests):
                d.set_value(s, share=False)
            return
        jobs = []
        for s, d in zip(srcs, dests):

            def callback(value, dest=d):
                dest.set_value(value, share=False)

            s = self._insert_barrier(s._symvar)
            # NOTE: the callback fires immediately in eager graph
            jobs.append(mgb.opr.callback_injector(s, callback))
        self._outspec.extend(jobs)
        if checkpoint:
            self._insert_checkpoint(*jobs, no_barrier=True)

    def _forward_inputs(self, *args, **kwargs):
        if self._kwargs is None:
            self._kwargs = kwargs
        elif self._kwargs != kwargs:
            raise ValueError("kwargs must not change between invocations")
        if self._args is None:
            self._args = []
            for i in args:
                if isinstance(i, Tensor):
                    self._args.append(tensor(dtype=i.dtype, device=i.device))
                    self._args[-1].set_value(i, share=False)
                else:
                    self._args.append(tensor(i))
        else:
            if not len(args) == len(self._args):
                raise TypeError
            for i, proxy in zip(args, self._args):
                proxy.set_value(i, share=False)
            # XXX: sync?

    def _make_outputs(self, outputs):
        if outputs is None:
            self._outputs = None
            return
        if isinstance(outputs, Tensor):
            # no one is able to call barrier after this, so no need to checkpoint
            # but a checkpoint does little harm anyway
            (self._outputs,) = self._make_proxies([outputs])
            return
        if not isinstance(outputs, (tuple, list)):
            raise TypeError("should return (tuple of) tensor")
        for i in outputs:
            if not isinstance(i, Tensor):
                raise TypeError("should return (tuple of) tensor")
        self._outputs = self._make_proxies(outputs)

    def _foward_outputs(self, outputs):
        # pseudo-op: does not run under static graph; traced
        if self._outputs is unset:
            self._make_outputs(outputs)
        if self._outputs is None:
            if outputs is not None:
                raise TypeError("should return None")
        elif isinstance(self._outputs, Tensor):
            if not isinstance(outputs, Tensor):
                raise TypeError("should return a tensor")
            self._forward([outputs], [self._outputs])
        else:
            assert isinstance(self._outputs, tuple)

            def check():
                if not isinstance(outputs, (tuple, list)):
                    return False
                if len(self._outputs) != len(outputs):
                    return False
                for x in outputs:
                    if not isinstance(x, Tensor):
                        return False
                return True

            if not check():
                raise TypeError(
                    "should return tuple of %d tensors" % len(self._outputs)
                )
            self._forward(outputs, self._outputs)

    def _apply_graph_options(self, cg):
        # graph opt level
        if self._graph_opt_level is not None:
            cg.set_option("graph_opt_level", self._graph_opt_level)
        # log level
        if self._log_level is not None:
            cg.set_option("log_level", self._log_level)
        # profile
        if self._profiling:
            self._profiler = CompGraphProfiler(cg)

    def _get_graph(self, eager):
        if eager:
            if not hasattr(self, "_eager_graph"):
                # pylint: disable=attribute-defined-outside-init
                self._eager_graph = graph.Graph(eager_evaluation=True)
                self._apply_graph_options(self._eager_graph)
            return self._eager_graph
        else:
            if not hasattr(self, "_static_graph"):
                # pylint: disable=attribute-defined-outside-init
                self._static_graph = graph.Graph(eager_evaluation=False)
                self._apply_graph_options(self._static_graph)
            return self._static_graph

    @contextlib.contextmanager
    def _prepare(self, args, kwargs, enable):
        # prepare for execution
        self._forward_inputs(*args, **kwargs)
        if not enable:
            # XXX: use our own graph here?
            cg = None
        elif self._status == self._FINISHED:
            cg = None
        elif self._symbolic:
            cg = self._get_graph(eager=False)
        else:
            cg = self._get_graph(eager=True)
        try:
            # NOTE: always trace in a new graph, so capturing an undetached tensor
            # will never work (would work if tracing in default graph)
            if cg is None:
                yield
            else:
                with cg:
                    yield
        finally:
            # XXX: properly release memory
            if cg:
                cg.clear_device_memory()

    @contextlib.contextmanager
    def _activate(self):
        # prepare for tracing
        if self._status != self._UNSTARTED:
            raise RuntimeError("cannot trace a second time")
        if type(self)._active_instance is not None:
            raise RuntimeError("nested trace is unsupported")
        self._status = self._STARTED
        type(self)._active_instance = self
        try:
            yield
        finally:
            self._status = self._FINISHED
            type(self)._active_instance = None

    def _run_wrapped(self):
        outputs = self.__wrapped__(*self._args, **self._kwargs)
        self._foward_outputs(outputs)
        return outputs

    def _do_trace(self):
        with self._activate():
            self._outspec = []
            outputs = self._run_wrapped()
            if outputs is None:
                self._sym_outputs = None
            else:
                if isinstance(outputs, Tensor):
                    outputs = [outputs]
                # _run_wrapped has checked validity of outputs
                self._sym_outputs = tuple(i._symvar for i in outputs)
            self._compiled_func = graph.get_default_graph().compile(None, self._outspec)

    def trace(self, *args: Tensor, **kwargs):
        """
        Trace wrapped callable with provided arguments.
        """
        with self._prepare(args, kwargs, enable=True):
            self._do_trace()
        return self

    def __call__(self, *args: Tensor, **kwargs):
        """
        Evaluate on provided arguments, using compiled trace
        instead of the original callable if applicable.

        :return: ``None`` or :class:`~.Tensor` or tuple of :class:`~.Tensor`, depending on the
            return value of wrapped callable.
        """
        with self._prepare(args, kwargs, enable=self.enabled):
            if not self.enabled:
                self._run_wrapped()
            elif self._status == self._FINISHED:
                self._compiled_func()
            else:
                if self._status == self._UNSTARTED:
                    self._do_trace()
                if self._symbolic:
                    self._compiled_func()
        return self._outputs

    def dump(
        self,
        fpath,
        *,
        arg_names=None,
        append=False,
        optimize_for_inference=False,
        **kwargs
    ):
        """
        Serialize trace to file system.

        :param fpath: positional only argument. Path of output file.
        :param arg_names: names of the input tensors in the traced function.
        :param append: whether output is appended to ``fpath``.
        :param f16_io_f32_comp: whether to use float16 for I/O between oprs and use
            float32 as internal computation precision. Note the output var would be
            changed to float16.
        :param f16_io_comp: whether to use float16 for both I/O and computation
            precision.
        :param use_nhwcd4: whether to use NHWCD4 data format. This is faster on some
            OpenCL devices.
        :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
            into one opr. This is supported only in NHWCD4 format.
        """
        if self._status != self._FINISHED:
            raise ValueError("not traced")
        assert isinstance(self._sym_outputs, (tuple, type(None)))
        if not self._sym_outputs:
            raise ValueError("no outputs")
        if arg_names is None:
            arg_names = ["arg_%d" % i for i in range(len(self._args))]
        elif len(arg_names) != len(self._args):
            raise ValueError(
                "len(arg_names) should be {}, got {}".format(
                    len(self._args), len(arg_names)
                )
            )
        optimize_for_inference_args_map = {
            "enable_io16xc32": "f16_io_f32_comp",
            "enable_ioc16": "f16_io_comp",
            "enable_hwcd4": "use_nhwcd4",
            "enable_nchw88": "use_nchw88",
            "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity",
            "enable_tensorcore": "use_tensor_core",
            "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z",
        }
        if optimize_for_inference:
            optimize_for_inference_kwargs = {}
            for k, v in optimize_for_inference_args_map.items():
                if kwargs.pop(k, False):
                    optimize_for_inference_kwargs[v] = True
        else:
            for k in optimize_for_inference_args_map:
                if kwargs.get(k, False):
                    raise ValueError(
                        "cannot set %s when optimize_for_inference is not set" % k
                    )
        if kwargs:
            raise ValueError("unknown options: %s" % list(kwargs))
        cg = self._sym_outputs[0].owner_graph
        replace = {}
        for t, name in zip(self._args, arg_names):
            # relies on symvar dedup
            s = t.__mgb_symvar__(comp_graph=cg)
            replace[s] = mgb.make_arg(
                t.device, cg, dtype=t.dtype, shape=t.shape, name=name
            )
        # Convert VolatileSharedDeviceTensor to SharedDeviceTensor,
        # otherwise some optimizations would not work. The conversion is
        # safe because there simply is no way (using builtin ops) to make
        # a VolatileSharedDeviceTensor actually volatile.
        for s in mgb.cgtools.get_dep_vars(
            self._sym_outputs, "VolatileSharedDeviceTensor"
        ):
            if s in replace:
                continue  # is an input
            replace[s] = mgb.SharedND._from_symvar(s).symvar(
                cg, name=s.name, volatile=False
            )
        sym_outputs = mgb.cgtools.replace_vars(self._sym_outputs, replace)
        sym_outputs = list(sym_outputs)
        if optimize_for_inference:
            sym_outputs = mgb.optimize_for_inference(
                sym_outputs, **optimize_for_inference_kwargs
            )
        mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append)

    def get_profile(self):
        """
        Get profiling result for compiled trace.

        :return: a json compatible object.
        """
        if not self._profiler:
            raise RuntimeError("trace is not set with profiling=True")
        return self._profiler.get()
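
The docstrings above outline the intended workflow: wrap a callable with trace, call the wrapped object to trigger tracing and compiled execution, then optionally dump the serialized graph or read back profiling data. Below is a minimal usage sketch under that reading; it assumes the 0.x-era API shown in this file and that megengine.tensor is available for constructing inputs, and the function f and the file name out.mge are purely illustrative.

    import numpy as np
    import megengine as mge
    from megengine.jit import trace

    @trace(symbolic=True, profiling=True)
    def f(x):
        # any Tensor computation; must return a Tensor, a tuple of Tensors, or None
        return x * 2 + 1

    x = mge.tensor(np.arange(4, dtype="float32"))
    y = f(x)  # the first call traces and compiles; later calls reuse the compiled graph
    f.dump("out.mge", arg_names=["x"])  # serialize the traced graph for deployment
    stats = f.get_profile()  # JSON-compatible profiling data (requires profiling=True)

As _get_graph above shows, symbolic=True builds the trace on a static graph, while the default traces on an eagerly evaluated graph.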

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.