
__init__.py 19 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
import functools
import itertools
import os
from typing import Callable, Tuple, Union

import numpy as np

import megengine._internal as mgb
from megengine._internal.plugin import CompGraphProfiler

from ..core import Tensor, graph, tensor
from .sublinear_memory_config import SublinearMemConfig

def sideeffect(f):
    # during eager tracing, wrapped function is called with proxy inputs
    # during static tracing, wrapped function will not be called at all
    @functools.wraps(f)
    def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
        if not trace._active_instance:
            return f(*args, **kwargs)
        tensors = {}
        for i, x in itertools.chain(enumerate(args), kwargs.items()):
            if isinstance(x, Tensor):
                tensors[i] = x
        if tensors:
            _keys, tensors = zip(*tensors.items())
        else:
            _keys, tensors = (), ()

        def callback(*tensors, f=f, keys=_keys, args=args, kwargs=kwargs):
            replace = dict(zip(keys, tensors))
            args = tuple(replace.get(i, x) for i, x in enumerate(args))
            kwargs = {i: replace.get(i, x) for i, x in kwargs.items()}
            if f(*args, **kwargs) is not None:
                raise TypeError("a sideeffect function should return None")
            # TODO: clear memory

        trace._active_instance._register_callback(callback, tensors)

    return wrapper


def mark_impure(x):
    if not trace._active_instance:
        return x
    return trace._active_instance._mark_impure(x)


def barrier(x):
    if not trace._active_instance:
        return x
    return trace._active_instance._insert_barrier(x)


def _dummy():
    return mgb.make_immutable(*graph._use_default_if_none(None, None), 0)


class unset:
    pass

class trace:
    """
    Wrap a callable and provide:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    :param func: Positional only argument.
    :param symbolic: Whether to use symbolic tensor. Default: False
    :param opt_level: Optimization level for compiling trace.
    :param log_level: Log level.
    :param enable_sublinear: Enable sublinear memory optimization. Default: False
    :param sublinear_mem_config: Configuration for sublinear memory optimization.
    :param profiling: Whether to profile compiled trace. Default: False
    """

    _active_instance = None
    enabled = not os.getenv("MGE_DISABLE_TRACE")

    _UNSTARTED = "unstarted"
    _STARTED = "started"
    _FINISHED = "finished"

    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        func: Callable[..., Union[None, Tensor, Tuple[Tensor]]],
        *,
        symbolic: bool = False,
        opt_level: int = None,
        log_level: int = None,
        enable_sublinear: bool = False,
        sublinear_mem_config: SublinearMemConfig = None,
        profiling: bool = False
    ):
        self.__wrapped__ = func
        self._symbolic = symbolic
        self._graph_opt_level = opt_level
        self._log_level = log_level
        self._enable_sublinear = enable_sublinear
        self._sublinear_mem_config = sublinear_mem_config
        self._status = self._UNSTARTED
        self._args = None
        self._kwargs = None
        self._outputs = unset
        self._sym_outputs = unset
        self._outspec = None
        self._checkpoint = None
        self._compiled_func = None
        self._profiling = profiling
        self._profiler = None
    @property
    def _active(self):
        c1 = self._status == self._STARTED
        c2 = type(self)._active_instance is self
        assert c1 == c2
        return c1

    def _register_callback(self, f, args=()):
        assert self._active
        assert isinstance(args, (tuple, list))
        proxies = self._make_proxies(args)
        self._forward(args, proxies, checkpoint=True)
        # NOTE: under eager graph callback will fire immediately
        job = mgb.opr.callback_injector(
            self._insert_barrier(_dummy()), lambda _: f(*proxies)
        )
        self._insert_checkpoint(job)
        self._outspec.append(job)

    def _insert_barrier(self, x):
        assert self._active
        if self._checkpoint is None:
            return x
        if isinstance(x, Tensor):
            x = x._symvar
            wrap = True
        else:
            wrap = False
        if not isinstance(x, mgb.SymbolVar):
            raise TypeError
        x = mgb.opr.virtual_dep([x, self._checkpoint])
        if wrap:
            x = Tensor(x)
        return x

    def _insert_checkpoint(self, *args, no_barrier=False):
        assert self._active
        if not args:
            return
        args = tuple(x._symvar if isinstance(x, Tensor) else x for x in args)
        for x in args:
            if not isinstance(x, mgb.SymbolVar):
                raise TypeError
        if not no_barrier and self._checkpoint is not None:
            # normally no need to _insert_barrier here, but if
            # someone forgets to call _insert_barrier beforehand,
            # this can make things less broken
            args += (self._checkpoint,)
        if len(args) == 1:
            self._checkpoint = args[0]
        else:
            self._checkpoint = mgb.opr.virtual_dep(args)

    def _mark_impure(self, x):
        assert self._active
        ret = x
        if isinstance(x, Tensor):
            x = x._symvar
        if not isinstance(x, mgb.SymbolVar):
            raise TypeError
        self._outspec.append(x)
        self._insert_checkpoint(x)
        return ret

    def _make_proxies(self, args):
        assert isinstance(args, (tuple, list))
        for x in args:
            assert isinstance(x, Tensor)
        return tuple(tensor(dtype=x.dtype, device=x.device) for x in args)
    def _forward(self, srcs, dests, checkpoint=True):
        # pseudo-op: does not run under static graph; traced
        # TODO: use shared memory
        assert len(srcs) == len(dests)
        if not self._active:
            for s, d in zip(srcs, dests):
                d.set_value(s, share=False)
            return
        jobs = []
        for s, d in zip(srcs, dests):

            def callback(value, dest=d):
                dest.set_value(value, share=False)

            s = self._insert_barrier(s._symvar)
            # NOTE: callbacks fire immediately in eager graph
            jobs.append(mgb.opr.callback_injector(s, callback))
        self._outspec.extend(jobs)
        if checkpoint:
            self._insert_checkpoint(*jobs, no_barrier=True)

    def _forward_inputs(self, *args, **kwargs):
        if self._kwargs is None:
            self._kwargs = kwargs
        elif self._kwargs != kwargs:
            raise ValueError("kwargs must not change between invocations")
        if self._args is None:
            self._args = []
            for i in args:
                if isinstance(i, Tensor):
                    self._args.append(tensor(dtype=i.dtype, device=i.device))
                    self._args[-1].set_value(i, share=False)
                else:
                    self._args.append(tensor(i))
        else:
            if not len(args) == len(self._args):
                raise TypeError
            for i, proxy in zip(args, self._args):
                proxy.set_value(i, share=False)
            # XXX: sync?

    def _make_outputs(self, outputs):
        if outputs is None:
            self._outputs = None
            return
        if isinstance(outputs, Tensor):
            # no one is able to call barrier after this, so no need to checkpoint
            # but checkpoint does little harm anyway
            (self._outputs,) = self._make_proxies([outputs])
            return
        if not isinstance(outputs, (tuple, list)):
            raise TypeError("should return (tuple of) tensor")
        for i in outputs:
            if not isinstance(i, Tensor):
                raise TypeError("should return (tuple of) tensor")
        self._outputs = self._make_proxies(outputs)

    def _foward_outputs(self, outputs):
        # pseudo-op: does not run under static graph; traced
        if self._outputs is unset:
            self._make_outputs(outputs)
        if self._outputs is None:
            if outputs is not None:
                raise TypeError("should return None")
        elif isinstance(self._outputs, Tensor):
            if not isinstance(outputs, Tensor):
                raise TypeError("should return a tensor")
            self._forward([outputs], [self._outputs])
        else:
            assert isinstance(self._outputs, tuple)

            def check():
                if not isinstance(outputs, (tuple, list)):
                    return False
                if len(self._outputs) != len(outputs):
                    return False
                for x in outputs:
                    if not isinstance(x, Tensor):
                        return False
                return True

            if not check():
                raise TypeError(
                    "should return tuple of %d tensors" % len(self._outputs)
                )
            self._forward(outputs, self._outputs)
    def _apply_graph_options(self, cg):
        # graph opt level
        if not (self._graph_opt_level is None):
            cg.set_option("graph_opt_level", self._graph_opt_level)
        # log level
        if not (self._log_level is None):
            cg.set_option("log_level", self._log_level)
        # sublinear
        if self._enable_sublinear:
            cg.set_option("enable_sublinear_memory_opt", True)
            if not (self._sublinear_mem_config is None):
                cg.set_option(
                    "sublinear_mem_cofig.lb_memory",
                    self._sublinear_mem_config.lb_memory,
                )
                cg.set_option(
                    "sublinear_mem_cofig.genetic_nr_iter",
                    self._sublinear_mem_config.genetic_nr_iter,
                )
                cg.set_option(
                    "sublinear_mem_cofig.genetic_pool_size",
                    self._sublinear_mem_config.genetic_pool_size,
                )
                cg.set_option(
                    "sublinear_mem_cofig.thresh_nr_try",
                    self._sublinear_mem_config.thresh_nr_try,
                )
                cg.set_option(
                    "sublinear_mem_cofig.num_worker",
                    self._sublinear_mem_config.num_worker,
                )
        # profile
        if self._profiling:
            self._profiler = CompGraphProfiler(cg)

    def _get_graph(self, eager):
        if eager:
            if not hasattr(self, "_eager_graph"):
                # pylint: disable=attribute-defined-outside-init
                self._eager_graph = graph.Graph(eager_evaluation=True)
                self._apply_graph_options(self._eager_graph)
            return self._eager_graph
        else:
            if not hasattr(self, "_static_graph"):
                # pylint: disable=attribute-defined-outside-init
                self._static_graph = graph.Graph(eager_evaluation=False)
                self._apply_graph_options(self._static_graph)
            return self._static_graph
    @contextlib.contextmanager
    def _prepare(self, args, kwargs, enable):
        # prepare for execution
        self._forward_inputs(*args, **kwargs)
        if not enable:
            # XXX: use our own graph here?
            cg = None
        elif self._status == self._FINISHED:
            cg = None
        elif self._symbolic:
            cg = self._get_graph(eager=False)
        else:
            cg = self._get_graph(eager=True)
        try:
            # NOTE: always trace in a new graph, so capturing an undetached tensor
            # will never work (would work if tracing in default graph)
            if cg is None:
                yield
            else:
                with cg:
                    yield
        finally:
            # XXX: properly release memory
            if cg:
                cg.clear_device_memory()

    @contextlib.contextmanager
    def _activate(self):
        # prepare for tracing
        if self._status != self._UNSTARTED:
            raise RuntimeError("cannot trace a second time")
        if type(self)._active_instance is not None:
            raise RuntimeError("nested trace is unsupported")
        self._status = self._STARTED
        type(self)._active_instance = self
        try:
            yield
        finally:
            self._status = self._FINISHED
            type(self)._active_instance = None

    def _run_wrapped(self):
        outputs = self.__wrapped__(*self._args, **self._kwargs)
        self._foward_outputs(outputs)
        return outputs
    def _do_trace(self):
        with self._activate():
            self._outspec = []
            outputs = self._run_wrapped()
            if outputs is None:
                self._sym_outputs = None
            else:
                if isinstance(outputs, Tensor):
                    outputs = [outputs]
                # _run_wrapped has checked validity of outputs
                self._sym_outputs = tuple(i._symvar for i in outputs)
            self._compiled_func = graph.get_default_graph().compile(None, self._outspec)

    def trace(self, *args: Tensor, **kwargs):
        """
        Trace wrapped callable with provided arguments.
        """
        with self._prepare(args, kwargs, enable=True):
            self._do_trace()
        return self

    def __call__(self, *args: Tensor, **kwargs):
        """
        Evaluate on provided arguments, using compiled trace
        instead of the original callable if applicable.

        :return: ``None`` or :class:`~.Tensor` or tuple of :class:`~.Tensor`, depending on the
            return value of wrapped callable.
        """
        with self._prepare(args, kwargs, enable=self.enabled):
            if not self.enabled:
                self._run_wrapped()
            elif self._status == self._FINISHED:
                self._compiled_func()
            else:
                if self._status == self._UNSTARTED:
                    self._do_trace()
                if self._symbolic:
                    self._compiled_func()
        return self._outputs
    def dump(
        self,
        fpath,
        *,
        arg_names=None,
        append=False,
        optimize_for_inference=False,
        **kwargs
    ):
        """
        Serialize trace to file system.

        :param fpath: positional only argument. Path of output file.
        :param arg_names: names of the input tensors in the traced function
        :param append: whether output is appended to ``fpath``
        :param f16_io_f32_comp: whether to use float16 for I/O between oprs and use
            float32 as internal computation precision. Note the output var would be
            changed to float16
        :param f16_io_comp: whether to use float16 for both I/O and computation
            precision
        :param use_nhwcd4: whether to use NHWCD4 data format. This is faster on some
            OpenCL devices
        :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
            into one opr. This is supported only in NHWCD4 format.
        """
        if self._status != self._FINISHED:
            raise ValueError("not traced")
        assert isinstance(self._sym_outputs, (tuple, type(None)))
        if not self._sym_outputs:
            raise ValueError("no outputs")
        if arg_names is None:
            arg_names = ["arg_%d" % i for i in range(len(self._args))]
        elif len(arg_names) != len(self._args):
            raise ValueError(
                "len(arg_names) should be {}, got {}".format(
                    len(self._args), len(arg_names)
                )
            )
        optimize_for_inference_args_map = {
            "enable_io16xc32": "f16_io_f32_comp",
            "enable_ioc16": "f16_io_comp",
            "enable_hwcd4": "use_nhwcd4",
            "enable_nchw88": "use_nchw88",
            "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity",
            "enable_tensorcore": "use_tensor_core",
            "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z",
        }
        if optimize_for_inference:
            optimize_for_inference_kwargs = {}
            for k, v in optimize_for_inference_args_map.items():
                if kwargs.pop(k, False):
                    optimize_for_inference_kwargs[v] = True
        else:
            for k in optimize_for_inference_args_map:
                if kwargs.get(k, False):
                    raise ValueError(
                        "cannot set %s when optimize_for_inference is not set" % k
                    )
        if kwargs:
            raise ValueError("unknown options: %s" % list(kwargs))

        cg = self._sym_outputs[0].owner_graph
        replace = {}
        for t, name in zip(self._args, arg_names):
            # relies on symvar dedup
            s = t.__mgb_symvar__(comp_graph=cg)
            replace[s] = mgb.make_arg(
                t.device, cg, dtype=t.dtype, shape=t.shape, name=name
            )
        # Convert VolatileSharedDeviceTensor to SharedDeviceTensor,
        # otherwise some optimizations would not work. The conversion is
        # safe because there simply is no way (using builtin ops) to make
        # a VolatileSharedDeviceTensor actually volatile.
        for s in mgb.cgtools.get_dep_vars(
            self._sym_outputs, "VolatileSharedDeviceTensor"
        ):
            if s in replace:
                continue  # is an input
            replace[s] = mgb.SharedND._from_symvar(s).symvar(
                cg, name=s.name, volatile=False
            )
        sym_outputs = mgb.cgtools.replace_vars(self._sym_outputs, replace)
        sym_outputs = list(sym_outputs)
        if optimize_for_inference:
            sym_outputs = mgb.optimize_for_inference(
                sym_outputs, **optimize_for_inference_kwargs
            )
        mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append)
    def get_profile(self):
        """
        Get profiling result for compiled trace.

        :return: a json compatible object.
        """
        if not self._profiler:
            raise RuntimeError("trace is not set with profiling=True")
        return self._profiler.get()
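
For orientation, here is a minimal usage sketch of the `trace` decorator defined above. It is illustrative only: it assumes this module is importable as `megengine.jit`, that `megengine.tensor` and basic elementwise operators are available, and the function `f`, the input `x`, and the file name "model.mge" are made-up placeholders.

import numpy as np

import megengine as mge
from megengine.jit import trace


# Hypothetical traced function; symbolic and profiling are the keyword-only
# options accepted by trace.__init__ above (used via __new__'s partial form).
@trace(symbolic=True, profiling=True)
def f(x):
    return x * 2 + 1


x = mge.tensor(np.arange(4, dtype="float32"))

y = f(x)                 # first call traces and compiles; later calls reuse the compiled func
print(f.get_profile())   # JSON-compatible profiling result (requires profiling=True)

# Serialize the traced graph; flags such as enable_io16xc32 map to the
# optimize_for_inference options documented in dump() above.
f.dump("model.mge", arg_names=["data"], optimize_for_inference=True)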

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.