
__init__.py 21 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
import functools
import itertools
import os
from typing import Callable, Tuple, Union

import numpy as np

import megengine._internal as mgb
from megengine._internal.plugin import CompGraphProfiler

from ..core import Tensor, graph, tensor
from .sublinear_memory_config import SublinearMemoryConfig

def sideeffect(f):
    # during eager tracing, wrapped function is called with proxy inputs
    # during static tracing, wrapped function will not be called at all
    @functools.wraps(f)
    def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
        if not trace._active_instance:
            return f(*args, **kwargs)
        tensors = {}
        for i, x in itertools.chain(enumerate(args), kwargs.items()):
            if isinstance(x, Tensor):
                tensors[i] = x
        if tensors:
            _keys, tensors = zip(*tensors.items())
        else:
            _keys, tensors = (), ()

        def callback(*tensors, f=f, keys=_keys, args=args, kwargs=kwargs):
            replace = dict(zip(keys, tensors))
            args = tuple(replace.get(i, x) for i, x in enumerate(args))
            kwargs = {i: replace.get(i, x) for i, x in kwargs.items()}
            if f(*args, **kwargs) is not None:
                raise TypeError("a sideeffect function should return None")
            # TODO: clear memory

        trace._active_instance._register_callback(callback, tensors)

    return wrapper


def mark_impure(x):
    if not trace._active_instance:
        return x
    return trace._active_instance._mark_impure(x)


def barrier(x):
    if not trace._active_instance:
        return x
    return trace._active_instance._insert_barrier(x)


def _dummy():
    return mgb.make_immutable(*graph._use_default_if_none(None, None), 0)


class unset:
    pass

class trace:
    """
    Wrap a callable and provide:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    :param func: Positional only argument.
    :param symbolic: Whether to use symbolic tensor. Default: False
    :param opt_level: Optimization level for compiling trace.
    :param log_level: Log level.
    :param sublinear_memory_config: Configuration for sublinear memory optimization.
        If not None, it enables sublinear memory optimization with the given setting.
    :param allreduce_pack_max_size: Maximum size of an allreduce pack in MB.
        If not None, multiple gradients will be packed and synchronized together.
    :param profiling: Whether to profile compiled trace. Default: False
    """

    _active_instance = None
    enabled = not os.getenv("MGE_DISABLE_TRACE")

    _UNSTARTED = "unstarted"
    _STARTED = "started"
    _FINISHED = "finished"

    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        func: Callable[..., Union[None, Tensor, Tuple[Tensor]]],
        *,
        symbolic: bool = False,
        opt_level: int = None,
        log_level: int = None,
        sublinear_memory_config: SublinearMemoryConfig = None,
        allreduce_pack_max_size: int = None,
        profiling: bool = False
    ):
        self.__wrapped__ = func
        self._symbolic = symbolic
        self._graph_opt_level = opt_level
        self._log_level = log_level
        self._sublinear_memory_config = sublinear_memory_config
        self._allreduce_pack_max_size = allreduce_pack_max_size
        self._status = self._UNSTARTED
        self._args = None
        self._kwargs = None
        self._outputs = unset
        self._sym_outputs = unset
        self._outspec = None
        self._checkpoint = None
        self._compiled_func = None
        self._profiling = profiling
        self._profiler = None

    @property
    def _active(self):
        c1 = self._status == self._STARTED
        c2 = type(self)._active_instance is self
        assert c1 == c2
        return c1

    def _register_callback(self, f, args=()):
        assert self._active
        assert isinstance(args, (tuple, list))
        proxies = self._make_proxies(args)
        self._forward(args, proxies, checkpoint=True)
        # NOTE: under eager graph the callback will fire immediately
        job = mgb.opr.callback_injector(
            self._insert_barrier(_dummy()), lambda _: f(*proxies)
        )
        self._insert_checkpoint(job)
        self._outspec.append(job)

    def _insert_barrier(self, x):
        assert self._active
        if self._checkpoint is None:
            return x
        if isinstance(x, Tensor):
            x = x._symvar
            wrap = True
        else:
            wrap = False
        if not isinstance(x, mgb.SymbolVar):
            raise TypeError
        x = mgb.opr.virtual_dep([x, self._checkpoint])
        if wrap:
            x = Tensor(x)
        return x

    def _insert_checkpoint(self, *args, no_barrier=False):
        assert self._active
        if not args:
            return
        args = tuple(x._symvar if isinstance(x, Tensor) else x for x in args)
        for x in args:
            if not isinstance(x, mgb.SymbolVar):
                raise TypeError
        if not no_barrier and self._checkpoint is not None:
            # normally there is no need to _insert_barrier here, but if
            # someone forgets to call _insert_barrier beforehand,
            # this can make things less broken
            args += (self._checkpoint,)
        if len(args) == 1:
            self._checkpoint = args[0]
        else:
            self._checkpoint = mgb.opr.virtual_dep(args)

    def _mark_impure(self, x):
        assert self._active
        ret = x
        if isinstance(x, Tensor):
            x = x._symvar
        if not isinstance(x, mgb.SymbolVar):
            raise TypeError
        self._outspec.append(x)
        self._insert_checkpoint(x)
        return ret

    def _make_proxies(self, args):
        assert isinstance(args, (tuple, list))
        for x in args:
            assert isinstance(x, Tensor)
        return tuple(tensor(dtype=x.dtype, device=x.device) for x in args)
    def _forward(self, srcs, dests, checkpoint=True):
        # pseudo-op: does not run under static graph; traced
        # TODO: use shared memory
        assert len(srcs) == len(dests)
        if not self._active:
            for s, d in zip(srcs, dests):
                d.set_value(s, share=False)
            return
        jobs = []
        for s, d in zip(srcs, dests):

            def callback(value, dest=d):
                dest.set_value(value, share=False)

            s = self._insert_barrier(s._symvar)
            # NOTE: the callback fires immediately in eager graph
            jobs.append(mgb.opr.callback_injector(s, callback))
        self._outspec.extend(jobs)
        if checkpoint:
            self._insert_checkpoint(*jobs, no_barrier=True)

    def _forward_inputs(self, *args, **kwargs):
        if self._kwargs is None:
            self._kwargs = kwargs
        elif self._kwargs != kwargs:
            raise ValueError("kwargs must not change between invocations")
        if self._args is None:
            self._args = []
            for i in args:
                if isinstance(i, Tensor):
                    self._args.append(tensor(dtype=i.dtype, device=i.device))
                    self._args[-1].set_value(i, share=False)
                else:
                    self._args.append(tensor(i))
        else:
            if not len(args) == len(self._args):
                raise TypeError
            for i, proxy in zip(args, self._args):
                proxy.set_value(i, share=False)
            # XXX: sync?

    def _make_outputs(self, outputs):
        if outputs is None:
            self._outputs = None
            return
        if isinstance(outputs, Tensor):
            # no one is able to call barrier after this, so there is no need to
            # checkpoint, but a checkpoint does little harm anyway
            (self._outputs,) = self._make_proxies([outputs])
            return
        if not isinstance(outputs, (tuple, list)):
            raise TypeError("should return (tuple of) tensor")
        for i in outputs:
            if not isinstance(i, Tensor):
                raise TypeError("should return (tuple of) tensor")
        self._outputs = self._make_proxies(outputs)

    def _foward_outputs(self, outputs):
        # pseudo-op: does not run under static graph; traced
        if self._outputs is unset:
            self._make_outputs(outputs)
        if self._outputs is None:
            if outputs is not None:
                raise TypeError("should return None")
        elif isinstance(self._outputs, Tensor):
            if not isinstance(outputs, Tensor):
                raise TypeError("should return a tensor")
            self._forward([outputs], [self._outputs])
        else:
            assert isinstance(self._outputs, tuple)

            def check():
                if not isinstance(outputs, (tuple, list)):
                    return False
                if len(self._outputs) != len(outputs):
                    return False
                for x in outputs:
                    if not isinstance(x, Tensor):
                        return False
                return True

            if not check():
                raise TypeError(
                    "should return tuple of %d tensors" % len(self._outputs)
                )
            self._forward(outputs, self._outputs)

    def _apply_graph_options(self, cg):
        # graph opt level
        if self._graph_opt_level is not None:
            cg.set_option("graph_opt_level", self._graph_opt_level)
        # log level
        if self._log_level is not None:
            cg.set_option("log_level", self._log_level)
        # sublinear
        if self._sublinear_memory_config is not None:
            cg.set_option("enable_sublinear_memory_opt", True)
            cg.set_option(
                "sublinear_mem_cofig.lb_memory",
                self._sublinear_memory_config.lb_memory,
            )
            cg.set_option(
                "sublinear_mem_cofig.genetic_nr_iter",
                self._sublinear_memory_config.genetic_nr_iter,
            )
            cg.set_option(
                "sublinear_mem_cofig.genetic_pool_size",
                self._sublinear_memory_config.genetic_pool_size,
            )
            cg.set_option(
                "sublinear_mem_cofig.thresh_nr_try",
                self._sublinear_memory_config.thresh_nr_try,
            )
            cg.set_option(
                "sublinear_mem_cofig.num_worker",
                self._sublinear_memory_config.num_worker,
            )
        # pack allreduce
        if self._allreduce_pack_max_size is not None:
            cg.set_option("allreduce_pack_max_size", self._allreduce_pack_max_size)
        # profile
        if self._profiling:
            self._profiler = CompGraphProfiler(cg)

    def _get_graph(self, eager):
        if eager:
            if not hasattr(self, "_eager_graph"):
                # pylint: disable=attribute-defined-outside-init
                self._eager_graph = graph.Graph(eager_evaluation=True)
                self._apply_graph_options(self._eager_graph)
            return self._eager_graph
        else:
            if not hasattr(self, "_static_graph"):
                # pylint: disable=attribute-defined-outside-init
                self._static_graph = graph.Graph(eager_evaluation=False)
                self._apply_graph_options(self._static_graph)
            return self._static_graph
    @contextlib.contextmanager
    def _prepare(self, args, kwargs, enable):
        # prepare for execution
        self._forward_inputs(*args, **kwargs)
        if not enable:
            # XXX: use our own graph here?
            cg = None
        elif self._status == self._FINISHED:
            cg = None
        elif self._symbolic:
            cg = self._get_graph(eager=False)
        else:
            cg = self._get_graph(eager=True)
        try:
            # NOTE: always trace in a new graph, so capturing an undetached tensor
            # will never work (would work if tracing in default graph)
            if cg is None:
                yield
            else:
                with cg:
                    yield
        finally:
            # XXX: properly release memory
            if cg:
                cg.clear_device_memory()

    @contextlib.contextmanager
    def _activate(self):
        # prepare for tracing
        if self._status != self._UNSTARTED:
            raise RuntimeError("cannot trace a second time")
        if type(self)._active_instance is not None:
            raise RuntimeError("nested trace is unsupported")
        self._status = self._STARTED
        type(self)._active_instance = self
        self._user_cache = {}
        try:
            yield
        finally:
            self._status = self._FINISHED
            self._user_cache = None
            type(self)._active_instance = None

    def _run_wrapped(self):
        outputs = self.__wrapped__(*self._args, **self._kwargs)
        self._foward_outputs(outputs)
        return outputs

    def _do_trace(self):
        with self._activate():
            self._outspec = []
            outputs = self._run_wrapped()
            if outputs is None:
                self._sym_outputs = None
            else:
                if isinstance(outputs, Tensor):
                    outputs = [outputs]
                # _run_wrapped has checked validity of outputs
                self._sym_outputs = tuple(i._symvar for i in outputs)
            mgb.comp_graph_tools.set_priority_to_id(self._outspec)
            self._compiled_func = graph.get_default_graph().compile(None, self._outspec)

    def trace(self, *args: Tensor, **kwargs):
        """
        Trace wrapped callable with provided arguments.
        """
        with self._prepare(args, kwargs, enable=True):
            self._do_trace()
        return self

    def __call__(self, *args: Tensor, **kwargs):
        """
        Evaluate on provided arguments, using compiled trace
        instead of the original callable if applicable.

        :return: ``None`` or :class:`~.Tensor` or tuple of :class:`~.Tensor`, depending on the
            return value of wrapped callable.
        """
        with self._prepare(args, kwargs, enable=self.enabled):
            if not self.enabled:
                self._run_wrapped()
            elif self._status == self._FINISHED:
                self._compiled_func()
            else:
                if self._status == self._UNSTARTED:
                    self._do_trace()
                    if self._symbolic:
                        self._compiled_func()
        return self._outputs
    def dump(
        self,
        fpath,
        *,
        arg_names=None,
        append=False,
        optimize_for_inference=False,
        **kwargs
    ):
        """
        Serialize trace to file system.

        :param fpath: positional only argument. Path of output file.
        :param arg_names: names of the input tensors in the traced function.
        :param append: whether output is appended to ``fpath``.
        :param optimize_for_inference: whether to enable optimize_for_inference
            pass before dump.
        :param enable_io16xc32: whether to use float16 for I/O between oprs and use
            float32 as internal computation precision. Note the output var would be
            changed to float16.
        :param enable_ioc16: whether to use float16 for both I/O and computation
            precision.
        :param enable_hwcd4: whether to use NHWCD4 data layout. This is faster on some
            OpenCL backends.
        :param enable_nchw88: whether to use NCHW88 data layout. It is currently
            used in the x86 AVX backend.
        :param enable_nchw44: whether to use NCHW44 data layout. It is currently
            used in the Arm backend.
        :param enable_nchw44_dot: whether to use NCHW44_DOT data layout. It is
            currently used in the armv8.2+dotprod backend.
        :param enable_nchw4: whether to use NCHW4 data layout. It is currently
            used in the NVIDIA backend (based on cuDNN).
        :param enable_nchw32: whether to use NCHW32 data layout. It is currently
            used in the NVIDIA backend with TensorCore (based on cuDNN).
        :param enable_chwn4: whether to use CHWN4 data layout. It is currently
            used in the NVIDIA backend with TensorCore.
        :param enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
            into one opr.
        :param enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
            input for inference on the NVIDIA backend (this optimization pass will
            result in a mismatch between the output precision of training and
            inference).
        """
        if self._status != self._FINISHED:
            raise ValueError("not traced")
        assert isinstance(self._sym_outputs, (tuple, type(None)))
        if not self._sym_outputs:
            raise ValueError("not outputs")
        if arg_names is None:
            arg_names = ["arg_%d" % i for i in range(len(self._args))]
        elif len(arg_names) != len(self._args):
            raise ValueError(
                "len(arg_names) should be {}, got {}".format(
                    len(self._args), len(arg_names)
                )
            )
        optimize_for_inference_args_map = {
            "enable_io16xc32": "f16_io_f32_comp",
            "enable_ioc16": "f16_io_comp",
            "enable_hwcd4": "use_nhwcd4",
            "enable_nchw4": "use_nchw4",
            "enable_nchw88": "use_nchw88",
            "enable_nchw32": "use_nchw32",
            "enable_nchw44": "use_nchw44",
            "enable_nchw44_dot": "use_nchw44_dot",
            "enable_chwn4": "use_chwn4",
            "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity",
            "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z",
        }
        if optimize_for_inference:
            optimize_for_inference_kwargs = {}
            for k, v in optimize_for_inference_args_map.items():
                if kwargs.pop(k, False):
                    optimize_for_inference_kwargs[v] = True
        else:
            for k in optimize_for_inference_args_map:
                if kwargs.get(k, False):
                    raise ValueError(
                        "cannot set %s when optimize_for_inference is not set" % k
                    )
        if kwargs:
            raise ValueError("unknown options: %s" % list(kwargs))
        cg = self._sym_outputs[0].owner_graph
        replace = {}
        for t, name in zip(self._args, arg_names):
            # relies on symvar dedup
            s = t.__mgb_symvar__(comp_graph=cg)
            replace[s] = mgb.make_arg(
                t.device, cg, dtype=t.dtype, shape=t.shape, name=name
            )
        # Convert VolatileSharedDeviceTensor to SharedDeviceTensor,
        # otherwise some optimizations would not work. The conversion is
        # safe because there simply is no way (using builtin ops) to make
        # a VolatileSharedDeviceTensor actually volatile.
        for s in mgb.cgtools.get_dep_vars(
            self._sym_outputs, "VolatileSharedDeviceTensor"
        ):
            if s in replace:
                continue  # is an input
            replace[s] = mgb.SharedND._from_symvar(s).symvar(
                cg, name=s.name, volatile=False
            )
        sym_outputs = mgb.cgtools.replace_vars(self._sym_outputs, replace)
        sym_outputs = list(sym_outputs)
        if optimize_for_inference:
            sym_outputs = mgb.optimize_for_inference(
                sym_outputs, **optimize_for_inference_kwargs
            )
        mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append)

    def get_profile(self):
        """
        Get profiling result for compiled trace.

        :return: a json compatible object.
        """
        if not self._profiler:
            raise RuntimeError("trace is not set with profiling=True")
        return self._profiler.get()
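For reference, below is a minimal usage sketch of this module, not part of the file above. It assumes the top-level megengine package exposes tensor and a functional namespace with relu, as in MegEngine v0.x; the function and file names are illustrative. The wrapped callable is traced on its first call, later calls reuse the compiled graph, and dump serializes the traced graph for inference.

import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.jit import trace


# symbolic=True builds a static graph; profiling=True enables get_profile()
@trace(symbolic=True, profiling=True)
def scaled_relu(x):
    return F.relu(x) * 2.0


x = mge.tensor(np.random.randn(4, 16).astype("float32"))
y = scaled_relu(x)  # first call traces and compiles, then evaluates
y = scaled_relu(x)  # subsequent calls reuse the compiled function

# serialize the traced graph; the enable_* flags in dump() require
# optimize_for_inference=True
scaled_relu.dump("scaled_relu.mge", arg_names=["x"], optimize_for_inference=True)

stats = scaled_relu.get_profile()  # JSON-compatible profiling data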

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is properly installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.