You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import contextlib
  10. import functools
  11. import itertools
  12. import os
  13. from typing import Callable, Tuple, Union
  14. import numpy as np
  15. import megengine._internal as mgb
  16. from megengine._internal.plugin import CompGraphProfiler
  17. from ..core import Tensor, graph, tensor
  18. from .sublinear_memory_config import SublinearMemoryConfig
  19. def sideeffect(f):
  20. # during eager tracing, wrapped function is called with proxy inputs
  21. # during static tracing, wrapped function will not be called at all
  22. @functools.wraps(f)
  23. def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
  24. if not trace._active_instance:
  25. return f(*args, **kwargs)
  26. tensors = {}
  27. for i, x in itertools.chain(enumerate(args), kwargs.items()):
  28. if isinstance(x, Tensor):
  29. tensors[i] = x
  30. if tensors:
  31. _keys, tensors = zip(*tensors.items())
  32. else:
  33. _keys, tensors = (), ()
  34. def callback(*tensors, f=f, keys=_keys, args=args, kwargs=kwargs):
  35. replace = dict(zip(keys, tensors))
  36. args = tuple(replace.get(i, x) for i, x in enumerate(args))
  37. kwargs = {i: replace.get(i, x) for i, x in kwargs.items()}
  38. if f(*args, **kwargs) is not None:
  39. raise TypeError("a sideeffect function should return None")
  40. # TODO: clear memory
  41. trace._active_instance._register_callback(callback, tensors)
  42. return wrapper
  43. def mark_impure(x):
  44. if not trace._active_instance:
  45. return x
  46. return trace._active_instance._mark_impure(x)
  47. def barrier(x):
  48. if not trace._active_instance:
  49. return x
  50. return trace._active_instance._insert_barrier(x)
  51. def _dummy():
  52. return mgb.make_immutable(*graph._use_default_if_none(None, None), 0)
  53. class unset:
  54. pass
  55. class trace:
  56. """
  57. Wrap a callable and provide:
  58. * tracing via :meth:`.trace` and :meth:`.dump`
  59. * accelerated evalutaion via :meth:`.__call__`
  60. :param func: Positional only argument.
  61. :param symbolic: Whether to use symbolic tensor. Default: False
  62. :param opt_level: Optimization level for compiling trace.
  63. :param log_level: Log level.
  64. :param sublinear_memory_config: Configuration for sublinear memory optimization.
  65. If not None, it enables sublinear memory optimization with given setting.
  66. :param profiling: Whether to profile compiled trace. Default: False
  67. """
  68. _active_instance = None
  69. enabled = not os.getenv("MGE_DISABLE_TRACE")
  70. _UNSTARTED = "unstarted"
  71. _STARTED = "started"
  72. _FINISHED = "finished"
  73. def __new__(cls, *args, **kwargs):
  74. if not args:
  75. return functools.partial(cls, **kwargs)
  76. return super().__new__(cls)
  77. def __init__(
  78. self,
  79. func: Callable[..., Union[None, Tensor, Tuple[Tensor]]],
  80. *,
  81. symbolic: bool = False,
  82. opt_level: int = None,
  83. log_level: int = None,
  84. sublinear_memory_config: SublinearMemoryConfig = None,
  85. profiling: bool = False
  86. ):
  87. self.__wrapped__ = func
  88. self._symbolic = symbolic
  89. self._graph_opt_level = opt_level
  90. self._log_level = log_level
  91. self._sublinear_memory_config = sublinear_memory_config
  92. self._status = self._UNSTARTED
  93. self._args = None
  94. self._kwargs = None
  95. self._outputs = unset
  96. self._sym_outputs = unset
  97. self._outspec = None
  98. self._checkpoint = None
  99. self._compiled_func = None
  100. self._profiling = profiling
  101. self._profiler = None
  102. @property
  103. def _active(self):
  104. c1 = self._status == self._STARTED
  105. c2 = type(self)._active_instance is self
  106. assert c1 == c2
  107. return c1
  108. def _register_callback(self, f, args=()):
  109. assert self._active
  110. assert isinstance(args, (tuple, list))
  111. proxies = self._make_proxies(args)
  112. self._forward(args, proxies, checkpoint=True)
  113. # NOTE: under eager graph callback will fire immediately
  114. job = mgb.opr.callback_injector(
  115. self._insert_barrier(_dummy()), lambda _: f(*proxies)
  116. )
  117. self._insert_checkpoint(job)
  118. self._outspec.append(job)
  119. def _insert_barrier(self, x):
  120. assert self._active
  121. if self._checkpoint is None:
  122. return x
  123. if isinstance(x, Tensor):
  124. x = x._symvar
  125. wrap = True
  126. else:
  127. wrap = False
  128. if not isinstance(x, mgb.SymbolVar):
  129. raise TypeError
  130. x = mgb.opr.virtual_dep([x, self._checkpoint])
  131. if wrap:
  132. x = Tensor(x)
  133. return x
  134. def _insert_checkpoint(self, *args, no_barrier=False):
  135. assert self._active
  136. if not args:
  137. return
  138. args = tuple(x._symvar if isinstance(x, Tensor) else x for x in args)
  139. for x in args:
  140. if not isinstance(x, mgb.SymbolVar):
  141. raise TypeError
  142. if not no_barrier and self._checkpoint is not None:
  143. # normally no need to _insert_barrier here, but if
  144. # someone forget to call _insert_barrier beforehand,
  145. # this can make things less broken
  146. args += (self._checkpoint,)
  147. if len(args) == 1:
  148. self._checkpoint = args[0]
  149. else:
  150. self._checkpoint = mgb.opr.virtual_dep(args)
  151. def _mark_impure(self, x):
  152. assert self._active
  153. ret = x
  154. if isinstance(x, Tensor):
  155. x = x._symvar
  156. if not isinstance(x, mgb.SymbolVar):
  157. raise TypeError
  158. self._outspec.append(x)
  159. self._insert_checkpoint(x)
  160. return ret
  161. def _make_proxies(self, args):
  162. assert isinstance(args, (tuple, list))
  163. for x in args:
  164. assert isinstance(x, Tensor)
  165. return tuple(tensor(dtype=x.dtype, device=x.device) for x in args)
  166. def _forward(self, srcs, dests, checkpoint=True):
  167. # pseudo-op: does not run under static graph; traced
  168. # TODO: use shared memory
  169. assert len(srcs) == len(dests)
  170. if not self._active:
  171. for s, d in zip(srcs, dests):
  172. d.set_value(s, share=False)
  173. return
  174. jobs = []
  175. for s, d in zip(srcs, dests):
  176. def callback(value, dest=d):
  177. dest.set_value(value, share=False)
  178. s = self._insert_barrier(s._symvar)
  179. # NOTE: callback immediately fire in eager graph
  180. jobs.append(mgb.opr.callback_injector(s, callback))
  181. self._outspec.extend(jobs)
  182. if checkpoint:
  183. self._insert_checkpoint(*jobs, no_barrier=True)
  184. def _forward_inputs(self, *args, **kwargs):
  185. if self._kwargs is None:
  186. self._kwargs = kwargs
  187. elif self._kwargs != kwargs:
  188. raise ValueError("kwargs must not change between invocations")
  189. if self._args is None:
  190. self._args = []
  191. for i in args:
  192. if isinstance(i, Tensor):
  193. self._args.append(tensor(dtype=i.dtype, device=i.device))
  194. self._args[-1].set_value(i, share=False)
  195. else:
  196. self._args.append(tensor(i))
  197. else:
  198. if not len(args) == len(self._args):
  199. raise TypeError
  200. for i, proxy in zip(args, self._args):
  201. proxy.set_value(i, share=False)
  202. # XXX: sync?
  203. def _make_outputs(self, outputs):
  204. if outputs is None:
  205. self._outputs = None
  206. return
  207. if isinstance(outputs, Tensor):
  208. # no one is able to call barrier after this, so no need to checkpoint
  209. # but checkpoint do little harm anyway
  210. (self._outputs,) = self._make_proxies([outputs])
  211. return
  212. if not isinstance(outputs, (tuple, list)):
  213. raise TypeError("should return (tuple of) tensor")
  214. for i in outputs:
  215. if not isinstance(i, Tensor):
  216. raise TypeError("should return (tuple of) tensor")
  217. self._outputs = self._make_proxies(outputs)
  218. def _foward_outputs(self, outputs):
  219. # pseudo-op: does not run under static graph; traced
  220. if self._outputs is unset:
  221. self._make_outputs(outputs)
  222. if self._outputs is None:
  223. if outputs is not None:
  224. raise TypeError("should return None")
  225. elif isinstance(self._outputs, Tensor):
  226. if not isinstance(outputs, Tensor):
  227. raise TypeError("should return a tensor")
  228. self._forward([outputs], [self._outputs])
  229. else:
  230. assert isinstance(self._outputs, tuple)
  231. def check():
  232. if not isinstance(outputs, (tuple, list)):
  233. return False
  234. if len(self._outputs) != len(outputs):
  235. return False
  236. for x in outputs:
  237. if not isinstance(x, Tensor):
  238. return False
  239. return True
  240. if not check():
  241. raise TypeError(
  242. "should return tuple of %d tensors" % len(self._outputs)
  243. )
  244. self._forward(outputs, self._outputs)
  245. def _apply_graph_options(self, cg):
  246. # graph opt level
  247. if self._graph_opt_level is not None:
  248. cg.set_option("graph_opt_level", self._graph_opt_level)
  249. # log level
  250. if self._log_level is not None:
  251. cg.set_option("log_level", self._log_level)
  252. # sublinear
  253. if self._sublinear_memory_config is not None:
  254. cg.set_option("enable_sublinear_memory_opt", True)
  255. cg.set_option(
  256. "sublinear_mem_cofig.lb_memory",
  257. self._sublinear_memory_config.lb_memory,
  258. )
  259. cg.set_option(
  260. "sublinear_mem_cofig.genetic_nr_iter",
  261. self._sublinear_memory_config.genetic_nr_iter,
  262. )
  263. cg.set_option(
  264. "sublinear_mem_cofig.genetic_pool_size",
  265. self._sublinear_memory_config.genetic_pool_size,
  266. )
  267. cg.set_option(
  268. "sublinear_mem_cofig.thresh_nr_try",
  269. self._sublinear_memory_config.thresh_nr_try,
  270. )
  271. cg.set_option(
  272. "sublinear_mem_cofig.num_worker",
  273. self._sublinear_memory_config.num_worker,
  274. )
  275. # profile
  276. if self._profiling:
  277. self._profiler = CompGraphProfiler(cg)
  278. def _get_graph(self, eager):
  279. if eager:
  280. if not hasattr(self, "_eager_graph"):
  281. # pylint: disable=attribute-defined-outside-init
  282. self._eager_graph = graph.Graph(eager_evaluation=True)
  283. self._apply_graph_options(self._eager_graph)
  284. return self._eager_graph
  285. else:
  286. if not hasattr(self, "_static_graph"):
  287. # pylint: disable=attribute-defined-outside-init
  288. self._static_graph = graph.Graph(eager_evaluation=False)
  289. self._apply_graph_options(self._static_graph)
  290. return self._static_graph
  291. @contextlib.contextmanager
  292. def _prepare(self, args, kwargs, enable):
  293. # prepare for execution
  294. self._forward_inputs(*args, **kwargs)
  295. if not enable:
  296. # XXX: use our own graph here?
  297. cg = None
  298. elif self._status == self._FINISHED:
  299. cg = None
  300. elif self._symbolic:
  301. cg = self._get_graph(eager=False)
  302. else:
  303. cg = self._get_graph(eager=True)
  304. try:
  305. # NOTE: always trace in a new graph, so capturing an undetached tensor
  306. # will never work (would work if tracing in default graph)
  307. if cg is None:
  308. yield
  309. else:
  310. with cg:
  311. yield
  312. finally:
  313. # XXX: properly release memory
  314. if cg:
  315. cg.clear_device_memory()
  316. @contextlib.contextmanager
  317. def _activate(self):
  318. # prepare for tracing
  319. if self._status != self._UNSTARTED:
  320. raise RuntimeError("cannot trace a second time")
  321. if type(self)._active_instance is not None:
  322. raise RuntimeError("nested trace is unsupported")
  323. self._status = self._STARTED
  324. type(self)._active_instance = self
  325. try:
  326. yield
  327. finally:
  328. self._status = self._FINISHED
  329. type(self)._active_instance = None
  330. def _run_wrapped(self):
  331. outputs = self.__wrapped__(*self._args, **self._kwargs)
  332. self._foward_outputs(outputs)
  333. return outputs
  334. def _do_trace(self):
  335. with self._activate():
  336. self._outspec = []
  337. outputs = self._run_wrapped()
  338. if outputs is None:
  339. self._sym_outputs = None
  340. else:
  341. if isinstance(outputs, Tensor):
  342. outputs = [outputs]
  343. # _run_wrapped has checked validity of outputs
  344. self._sym_outputs = tuple(i._symvar for i in outputs)
  345. self._compiled_func = graph.get_default_graph().compile(None, self._outspec)
  346. def trace(self, *args: Tensor, **kwargs):
  347. """
  348. Trace wrapped callable with provided arguments.
  349. """
  350. with self._prepare(args, kwargs, enable=True):
  351. self._do_trace()
  352. return self
  353. def __call__(self, *args: Tensor, **kwargs):
  354. """
  355. Evaluate on provided arguments, using compiled trace
  356. instead of the original callable if applicable.
  357. :return: ``None`` or :class:`~.Tensor` or tuple of :class:`~.Tensor`, depending on the
  358. return value of wrapped callable.
  359. """
  360. with self._prepare(args, kwargs, enable=self.enabled):
  361. if not self.enabled:
  362. self._run_wrapped()
  363. elif self._status == self._FINISHED:
  364. self._compiled_func()
  365. else:
  366. if self._status == self._UNSTARTED:
  367. self._do_trace()
  368. if self._symbolic:
  369. self._compiled_func()
  370. return self._outputs
  371. def dump(
  372. self,
  373. fpath,
  374. *,
  375. arg_names=None,
  376. append=False,
  377. optimize_for_inference=False,
  378. **kwargs
  379. ):
  380. """
  381. Serialize trace to file system.
  382. :param fpath: positional only argument. Path of output file.
  383. :param arg_names: names of the input tensors in the traced function
  384. :param append: whether output is appended to ``fpath``
  385. :param f16_io_f32_comp: whether to use float16 for I/O between oprs and use
  386. float32 as internal computation precision. Note the output var would be
  387. changed to float16
  388. :param f16_io_comp: whether to use float16 for both I/O and computation
  389. precision
  390. :param use_nhwcd4: whether to use NHWCD4 data format. This is faster on some
  391. OpenCL devices
  392. :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
  393. into one opr. This is supported only in NHWCD4 format.
  394. """
  395. if self._status != self._FINISHED:
  396. raise ValueError("not traced")
  397. assert isinstance(self._sym_outputs, (tuple, type(None)))
  398. if not self._sym_outputs:
  399. raise ValueError("not outputs")
  400. if arg_names is None:
  401. arg_names = ["arg_%d" % i for i in range(len(self._args))]
  402. elif len(arg_names) != len(self._args):
  403. raise ValueError(
  404. "len(arg_names) should be {}, got {}".format(
  405. len(self._args), len(arg_names)
  406. )
  407. )
  408. optimize_for_inference_args_map = {
  409. "enable_io16xc32": "f16_io_f32_comp",
  410. "enable_ioc16": "f16_io_comp",
  411. "enable_hwcd4": "use_nhwcd4",
  412. "enable_nchw88": "use_nchw88",
  413. "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity",
  414. "enable_tensorcore": "use_tensor_core",
  415. "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z",
  416. }
  417. if optimize_for_inference:
  418. optimize_for_inference_kwargs = {}
  419. for k, v in optimize_for_inference_args_map.items():
  420. if kwargs.pop(k, False):
  421. optimize_for_inference_kwargs[v] = True
  422. else:
  423. for k in optimize_for_inference_args_map:
  424. if kwargs.get(k, False):
  425. raise ValueError(
  426. "cannot set %s when optimize_for_inference is not set" % k
  427. )
  428. if kwargs:
  429. raise ValueError("unknown options: %s" % list(kwargs))
  430. cg = self._sym_outputs[0].owner_graph
  431. replace = {}
  432. for t, name in zip(self._args, arg_names):
  433. # relies on symvar dedup
  434. s = t.__mgb_symvar__(comp_graph=cg)
  435. replace[s] = mgb.make_arg(
  436. t.device, cg, dtype=t.dtype, shape=t.shape, name=name
  437. )
  438. # Convert VolatileSharedDeviceTensor to SharedDeviceTensor,
  439. # otherwise some optimizations would not work. The conversion is
  440. # safe because there simply is no way (using builtin ops) to make
  441. # a VolatileSharedDeviceTensor actually volatile.
  442. for s in mgb.cgtools.get_dep_vars(
  443. self._sym_outputs, "VolatileSharedDeviceTensor"
  444. ):
  445. if s in replace:
  446. continue # is an input
  447. replace[s] = mgb.SharedND._from_symvar(s).symvar(
  448. cg, name=s.name, volatile=False
  449. )
  450. sym_outputs = mgb.cgtools.replace_vars(self._sym_outputs, replace)
  451. sym_outputs = list(sym_outputs)
  452. if optimize_for_inference:
  453. sym_outputs = mgb.optimize_for_inference(
  454. sym_outputs, **optimize_for_inference_kwargs
  455. )
  456. mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append)
  457. def get_profile(self):
  458. """
  459. Get profiling result for compiled trace.
  460. :return: a json compatible object.
  461. """
  462. if not self._profiler:
  463. raise RuntimeError("trace is not set with profiling=True")
  464. return self._profiler.get()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。