
__init__.py 23 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""the megbrain python package

Note that all the submodules are automatically imported, so you usually only
need to ``import megengine._internal as mgb``.
"""

import collections
import json
import os
import sys

import numpy as np

from . import comp_graph_tools as cgtools
from . import config, craniotome, dtype
from . import global_init as _global_init
from . import helper as _helper
from . import mgb as _detail
from . import opr, opr_extra, opr_param_defs, plugin
from .exc import MegBrainError
from .logconf import get_logger
from .mgb import (
    CompGraph,
    CompNode,
    SharedND,
    SharedScalar,
    SymbolVar,
    TensorValueDumperContext,
    TensorValueLoaderContext,
)
from .mgb import as_comp_node as comp_node
from .mgb_helper import SharedNDLazyInitializer, callback_lazycopy, copy_output
from .plugin import CompGraphProfiler
from .plugin import GlobalInfkernFinder as _GlobalInfkernFinder
from .plugin import NumRangeChecker
from .version import __version__, version_info

if sys.version_info.major < 3:
    raise ImportError("megbrain requires python 3")


class ProxySharedNDAndSymbolVar(_detail.SymbolVar):
    """this is a :class:`.SymbolVar` with a corresponding :class:`.SharedND`.
    It can participate in graph computing and also provides :meth:`set_value`
    and :meth:`get_value`. It should be constructed by :func:`make_shared`.
    """

    __shared_nd = None
    __kwargs = None

    def __init__(self, snd, comp_graph, name, **kwargs):
        self.__shared_nd = snd
        self.__kwargs = kwargs
        self.this = snd.symvar(comp_graph=comp_graph, name=name, **kwargs).this

    def set_value(self, v, **kwargs):
        ret = self.__shared_nd.set_value(v, **kwargs)
        self._reeval_if_eager_eval()
        return ret

    def get_value(self):
        return self.__shared_nd.get_value()

    def reset_zero(self):
        self.__shared_nd.reset_zero()


def make_shared(
    comp_node,
    *,
    dtype=None,
    shape=None,
    value=None,
    comp_graph=None,
    name=None,
    volatile=None
):
    """make a shared tensor which is stored on device and could be modified
    later, either as a :class:`.SymbolVar` or a :class:`.SharedND` object

    :param comp_node: computing node
    :type comp_node: :class:`.CompNode`
    :param dtype: data type; if it is None, then dtype of value would be used
        if value is not None, and float32 would be used as default dtype if
        value is None
    :type dtype: :class:`numpy.dtype` compatible
    :param value: initializing value
    :type value: None or :class:`numpy.ndarray`
    :param comp_graph: the computing graph to which this shared value should
        belong; if provided, the returned object could be used as a
        :class:`.SymbolVar`
    :type comp_graph: None or :class:`.CompGraph`
    :param name: node name to be used in computing graph; only meaningful if
        *comp_graph* is provided
    :param volatile: if *comp_graph* is given then *volatile* indicates whether
        shape or mem ptr of this SharedND can be changed
    :rtype: :class:`.SharedND` if *comp_graph* is not given; or
        :class:`ProxySharedNDAndSymbolVar` otherwise
    """
    if dtype is None:
        if value is not None:
            value = np.ascontiguousarray(value)
            dtype = to_mgb_supported_dtype(value.dtype)
        else:
            dtype = np.float32

    comp_node = _detail.as_comp_node(comp_node)
    rst = _detail.SharedND(comp_node, dtype)
    if value is not None:
        assert shape is None, "could not provide both value and shape"
        rst.set_value(value)
    elif shape is not None:
        rst._set_init_shape(shape)

    if comp_graph is None:
        assert name is None and volatile is None
        return rst

    assert isinstance(comp_graph, CompGraph), "expect CompGraph but got {}".format(
        comp_graph
    )
    if volatile is None:
        volatile = False
    else:
        assert isinstance(volatile, bool)
    return ProxySharedNDAndSymbolVar(rst, comp_graph, name, volatile=volatile)
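

# --- illustrative usage sketch (editor's addition, not part of the original file) ---
# A minimal example of make_shared in both of its modes. The device locator
# "cpu0" and the demo values are assumptions chosen for illustration only.
def _example_make_shared():
    # plain SharedND: stored on the device, not attached to any graph
    w = make_shared("cpu0", value=np.ones((2, 3), dtype=np.float32))
    print(w.get_value())  # read back as a numpy array
    w.reset_zero()        # clear the stored tensor

    # attached to a graph: the result is also usable as a SymbolVar
    cg = comp_graph()
    p = make_shared(
        "cpu0", value=np.zeros((4,), dtype=np.float32), comp_graph=cg, name="p"
    )
    p.set_value(np.arange(4, dtype=np.float32))
    return p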


def make_immutable(comp_node, comp_graph, value, *, dtype=None, name=None):
    """make a graph node containing an immutable tensor from host tensor value

    :param dtype: required data type; if not None, the data would be converted
        to that type; otherwise the dtype would be inferred from *value*
    """
    comp_node = _detail.as_comp_node(comp_node)
    assert isinstance(
        comp_graph, _detail.CompGraph
    ), "expect CompGraph but got {!r}".format(comp_graph)

    config = _detail.make_opr_config(name, comp_node)
    return _helper.cvt_opr_result(
        _detail._make_immutable(comp_graph, value, dtype, config)
    )


def make_arg(
    comp_node,
    comp_graph,
    *,
    dtype=np.float32,
    shape=None,
    name=None,
    value=None,
    enable_static_infer=True
):
    """make an argument to be passed to compiled function during runtime

    :type shape: None or tuple of int
    :param shape: expected tensor shape to be used for shape inferring; actual
        tensor shape could be different
    :type name: str
    :param name: name of the generated var node
    :type value: None or ndarray-compatible
    :param value: initial value used for static inference; if not given, static
        infer would be deferred to first graph execution
    :param enable_static_infer: whether to enable static inference for this var
    """
    comp_node = _detail.as_comp_node(comp_node)
    host_val = mgb._HostSharedND(comp_node, dtype)

    if value is not None:
        value = np.ascontiguousarray(value, dtype=dtype)
        if shape is None:
            shape = value.shape
        else:
            assert shape == value.shape
    if shape is not None:
        host_val._resize(shape)

    if value is not None:
        host_val.set_value(value)

    return _helper.cvt_opr_result(
        ProxySharedNDAndSymbolVar(
            host_val, comp_graph, name, enable_static_infer=enable_static_infer
        )
    )
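

# --- illustrative usage sketch (editor's addition, not part of the original file) ---
# make_arg creates a runtime-fed input var; the shape is only a hint for static
# inference. The locator "cpu0", the shape and the names are assumptions for
# illustration.
def _example_make_arg():
    cg = comp_graph()
    data = make_arg("cpu0", cg, dtype=np.float32, shape=(16, 3, 32, 32), name="data")
    # a constant graph node built from a host value, usable in the same graph
    scale = make_immutable("cpu0", cg, np.float32(0.5), name="scale")
    return data, scale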


def comp_graph(*, extra_opts=None, check_env_var=True):
    """allocate a new computing graph

    :param extra_opts: extra options to be set; would be updated (modified
        inplace) from ``MGB_COMP_GRAPH_OPT`` environment var. See
        :func:`.set_comp_graph_option` for list of supported options.
    :type extra_opts: dict
    :param check_env_var: whether to check environment vars
    :type check_env_var: bool
    :return: the comp graph object
    :rtype: :class:`.CompGraph`
    """
    cg = _detail.CompGraph()
    if extra_opts is None:
        extra_opts = {}
    if check_env_var:
        setting = os.getenv("MGB_COMP_GRAPH_OPT")
        if setting:
            for item in setting.split(";"):
                k, v = item.split("=", 1)
                extra_opts.setdefault(k, v)
            get_logger().warning(
                "set comp graph option from env: {}".format(extra_opts)
            )
        user_data = os.getenv("MGB_COMP_GRAPH_USER_DATA")
        if user_data:
            storage = cg.user_data
            for ud in user_data.split(";"):
                k, v = ud.split("=", 1)
                storage[k] = eval(v)
        _GlobalInfkernFinder.add_graph(cg)
    for k, v in extra_opts.items():
        cg.set_option(k, v)
    return cg
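

# --- illustrative usage sketch (editor's addition, not part of the original file) ---
# Graph options can be passed via extra_opts or picked up from the
# MGB_COMP_GRAPH_OPT environment variable ("key=value" pairs joined by ';').
# "graph_opt_level" is used as the example option here (it is referenced by
# grad() below); consult set_comp_graph_option for the options your build supports.
def _example_comp_graph_options():
    os.environ["MGB_COMP_GRAPH_OPT"] = "graph_opt_level=2"
    cg = comp_graph(extra_opts={})  # env settings fill in whatever is not given
    return cg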


def grad(
    target, wrt, warn_mid_wrt=True, use_virtual_grad=None, return_zero_for_nodep=True
):
    r"""compute symbolic grad

    :param target: grad target var
    :type target: :class:`.SymbolVar`
    :param wrt: with respect to which to compute the grad
    :type wrt: :class:`.SymbolVar` or Iterable[SymbolVar]
    :param warn_mid_wrt: whether to give warning if *wrt* is not endpoint
    :type warn_mid_wrt: bool
    :param use_virtual_grad: whether to use virtual grad opr, so fwd graph can
        be optimized before applying grad; if ``None`` is given, then virtual
        grad would be used if ``graph_opt_level >= 2``
    :type use_virtual_grad: :class:`bool` or ``None``
    :param return_zero_for_nodep: if *target* does not depend on *wrt*, set to True to
        return a zero-valued :class:`.SymbolVar` rather than ``None``; can't be set to
        False when using virtual grad opr.
    :type return_zero_for_nodep: bool
    :rtype: :class:`.SymbolVar` or None
    :return: :math:`\frac{\partial\text{target}}{\partial\text{wrt}}`
    """
    if use_virtual_grad is None:
        use_virtual_grad = -1
    else:
        use_virtual_grad = 1 if use_virtual_grad else 0
    if isinstance(wrt, SymbolVar):
        wrts = [
            wrt,
        ]
    else:
        wrts = wrt
    assert isinstance(wrts, collections.Iterable)

    # return an invalid SymbolVar (with nullptr VarNode*) when return_zero_for_nodep
    # is False and target doesn't depend on wrt
    grads = _detail._grad(
        target, wrts, bool(warn_mid_wrt), use_virtual_grad, return_zero_for_nodep
    )
    grads = list(grads)

    for i in range(len(grads)):
        if not grads[i].valid:
            assert (
                not return_zero_for_nodep
            ), "invalid grad SymbolVar: target={}, wrt={}".format(target, wrts[i])
            grads[i] = None

    if len(grads) == 1:
        grads = grads[0]
    return grads
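

# --- illustrative usage sketch (editor's addition, not part of the original file) ---
# Symbolic grad of a tiny expression. This assumes SymbolVar supports arithmetic
# operator overloading (x * x); the locator "cpu0" and the values are assumptions
# chosen for illustration.
def _example_grad():
    cg = comp_graph()
    x = make_shared(
        "cpu0", value=np.array([3.0], dtype=np.float32), comp_graph=cg, name="x"
    )
    loss = x * x        # assumed elementwise multiply on SymbolVar
    gx = grad(loss, x)  # d(loss)/dx, itself a SymbolVar in the same graph
    return gx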


def current_grad_target(comp_graph):
    """get current target var to compute grad, used for implementing custom
    gradient"""
    return _detail._current_grad_target(comp_graph)


def add_device_map(map_location):
    """add map location while loading models"""
    _detail.CompNode.cn_thread_local.__setattr__("map_location", map_location)


def del_device_map():
    """delete map location"""
    _detail.CompNode.cn_thread_local.__delattr__("map_location")


def inter_graph_trans_var(dest_graph, src):
    """get the corresponding var of *src* in *dest_graph*; assuming
    *dest_graph* is a copy of owner graph of *src*; usually used in callback of
    set_grad to get grad of vars in loop

    :param dest_graph: target computing graph
    :type dest_graph: :class:`.CompGraph`
    :param src: source var node
    :type src: :class:`.SymbolVar`
    :return: corresponding var in *dest_graph*
    :rtype: :class:`.SymbolVar`
    """
    return _detail._inter_graph_trans_var(dest_graph, src)


def get_graph_optimizer_replaced_var(src):
    """get optimized var corresponding to given var; usually used in callback
    of set_grad to get grad w.r.t. some var

    :param src: source var node
    :type src: :class:`.SymbolVar`
    :rtype: :class:`.SymbolVar`
    """
    return _detail._get_graph_optimizer_replaced_var(src)


CompGraphSerializationResult = collections.namedtuple(
    "CompGraphSerializationResult",
    [
        "nr_opr",
        "tot_bytes",
        "tensor_value_bytes",
        "content_hash",
        "inputs",
        "outputs",
        "params",
    ],
)


def serialize_comp_graph_to_file(
    fpath,
    output_vars,
    *,
    keep_var_name=1,
    keep_param_name=False,
    keep_opr_priority=False,
    tensor_value_dumper=None,
    output_strip_info=False,
    append=False,
    format=None,
    **kwargs
):
    """serialize this computing graph and write result to a file. Note:
    ``kwargs`` exists for backward compatibility; there are no additional
    arguments.

    :param fpath: path for the output file
    :type fpath: ``str``
    :param output_vars: output variables that need to be retrieved when
        deserializing

        .. note::

            The underlying C++ API only accepts a var list. If a dict is given,
            the vars would be renamed to given names.

    :type output_vars: dict(name => :class:`.SymbolVar`), or a list of vars
    :param keep_var_name: level for keeping variable names:

        * 0: none of the names are kept
        * 1: keep names of output vars
        * 2: keep names of all (output and internal) vars

    :param keep_param_name: whether to keep param names, so param values can be
        easily manipulated after loading model
    :param keep_opr_priority: whether to keep priority setting for operators
    :param tensor_value_dumper: a callable to dump tensor values; it should
        only write the tensor value without layout information. It would be
        given a :class:`.TensorValueDumperContext` object as its sole argument.
    :param output_strip_info: if set to True, then a json file containing
        information for code strip would be written to ``fpath+'.json'``
    :param append: whether to open output file in append mode
    :param format: serialization format of the resulting model, should be either
        "mdl" or "fbs"; None means default.
    :type format: ``str``
    :return: an instance of namedtuple :class:`CompGraphSerializationResult`,
        whose fields are:

        * ``nr_opr`` number of operators dumped
        * ``tot_bytes`` total bytes for the whole graph
        * ``tensor_value_bytes`` bytes consumed for dumping tensor values
        * ``inputs`` names of input tensors
        * ``params`` list of names of dumped params
        * ``outputs`` names of output vars
    """
    assert isinstance(fpath, str), "bad file path: {!r}".format(fpath)
    ov = _detail._VectorSymbolVar()

    SUPPORTED_FORMATS = {
        # default
        None: _detail.GraphDumpFormat_FLATBUFFERS,
        "fbs": _detail.GraphDumpFormat_FLATBUFFERS,
    }
    resolved_fmt = SUPPORTED_FORMATS.get(format, None)
    if resolved_fmt is None:
        raise ValueError(
            "unknown format {} requested, supported ones are {}".format(
                format, list(filter(None, SUPPORTED_FORMATS.keys()))
            )
        )

    if isinstance(output_vars, dict):
        used_vars = set()
        for name, var in output_vars.items():
            assert isinstance(var, _detail.SymbolVar), "bad output var: {!r}".format(
                var
            )
            assert var.id not in used_vars, (
                "var name is associated with a var object, so we can not have "
                "two names given to the same var: {}".format(var)
            )
            used_vars.add(var.id)
            var.rename(name)
            ov.push_back(var)
    else:
        for i in output_vars:
            assert isinstance(i, _detail.SymbolVar), "bad output var: {!r}".format(i)
            ov.push_back(i)

    if tensor_value_dumper is not None:
        assert isinstance(tensor_value_dumper, collections.Callable)

        class Callback(_detail._TensorValueDumperCallback):
            def call(self, ctx, *, _f=tensor_value_dumper):
                _f(ctx)

        tensor_value_dumper = Callback()

    # for backward compatibility
    mangle_opr_name = kwargs.pop("mangle_opr_name", ov)
    if mangle_opr_name is not ov:
        get_logger().warning("mangle_opr_name is deprecated; use keep_var_name instead")
        keep_var_name = 1 if mangle_opr_name else 2
    mangle_param_name = kwargs.pop("mangle_param_name", ov)
    assert (
        not kwargs
    ), "extra kwargs provided to serialize_comp_graph_to_file: {}".format(kwargs)

    if mangle_param_name is not ov:
        get_logger().warning(
            "mangle_param_name is deprecated; use keep_param_name instead"
        )
        keep_param_name = not mangle_param_name

    inputs = _detail._VectorString()
    outputs = _detail._VectorString()
    params = _detail._VectorString()
    stat = _detail._VectorSizeT()

    _detail._serialize_comp_graph_to_file(
        fpath,
        append,
        resolved_fmt,
        ov,
        keep_var_name,
        keep_param_name,
        keep_opr_priority,
        tensor_value_dumper,
        stat,
        inputs,
        outputs,
        params,
    )

    dump_ret = CompGraphSerializationResult(
        *stat, list(inputs), list(outputs), list(params)
    )

    if output_strip_info:
        with open(fpath + ".json", "w") as fout:
            strip_info = _detail._get_info_for_strip(ov)
            strip_info_dict = json.loads(strip_info)
            strip_info_dict["hash"] = dump_ret.content_hash
            json.dump(strip_info_dict, fout)

    return dump_ret
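

# --- illustrative usage sketch (editor's addition, not part of the original file) ---
# Dumping a graph for deployment. The file name "model.mgb" and the output name
# "prob" are assumptions; passing a dict renames the output vars, and
# output_strip_info additionally writes "model.mgb.json" with code-strip info.
def _example_serialize(output_var):
    result = serialize_comp_graph_to_file(
        "model.mgb",
        {"prob": output_var},  # rename the single output var to "prob"
        keep_var_name=2,       # keep names of all vars, not just outputs
        output_strip_info=True,
    )
    print(result.nr_opr, result.tot_bytes, list(result.outputs))
    return result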


CompGraphLoadResult = collections.namedtuple(
    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
)


def load_comp_graph_from_file(
    fpath, *, comp_node_mapper=None, tensor_value_loader=None
):
    """Load a serialized computing graph from file.

    :param fpath: Path for the input file
    :type fpath: ``str``
    :param comp_node_mapper: A callable to modify comp node locator, takes old
        locator as argument and returns new locator.
    :type comp_node_mapper: Callable[[str], str]
    :param tensor_value_loader: A callable to load tensor values. It should
        read the tensor value with the given shape and dtype and return it as
        NumPy ndarray. It would be given a :class:`.TensorValueLoaderContext`
        object as its sole argument.
    :type tensor_value_loader: Callable[[TensorValueLoaderContext], numpy.ndarray]
    :return: An instance of namedtuple :class:`CompGraphLoadResult`,
        whose fields are:

        * ``graph`` loaded CompGraph
        * ``output_vars_dict`` A Python dict, mapping name to output SymbolVar
        * ``output_vars_list`` A Python list, containing output vars in the
          order passed to serialize_comp_graph_to_file
    """
    assert isinstance(fpath, str), "bad file path: {!r}".format(fpath)

    if comp_node_mapper is not None:
        assert isinstance(comp_node_mapper, collections.Callable)

        class Callback(_detail._CompNodeMapperCallback):
            def call(self, desc, *, _f=comp_node_mapper):
                return _f(desc)

        comp_node_mapper = Callback()
    if tensor_value_loader is not None:
        assert isinstance(tensor_value_loader, collections.Callable)

        class Callback(_detail._TensorValueLoaderCallback):
            def call(self, ctx, *, _f=tensor_value_loader):
                return _f(ctx)

        tensor_value_loader = Callback()

    output_vars_map = _detail._VectorPairStringSymbolVar()
    output_vars_list = _detail._VectorSymbolVar()
    cg = _detail._load_comp_graph_from_file(
        fpath, comp_node_mapper, tensor_value_loader, output_vars_map, output_vars_list
    )
    return CompGraphLoadResult(cg, dict(list(output_vars_map)), list(output_vars_list))
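

# --- illustrative usage sketch (editor's addition, not part of the original file) ---
# Loading a dumped model while remapping device locators. The "gpu" -> "cpu"
# rewrite assumes locators are plain strings such as "gpu0", and that an output
# named "prob" was kept when the model was dumped.
def _example_load():
    ret = load_comp_graph_from_file(
        "model.mgb",
        comp_node_mapper=lambda loc: loc.replace("gpu", "cpu"),
    )
    cg = ret.graph                       # the loaded CompGraph
    prob = ret.output_vars_dict["prob"]  # output var looked up by its kept name
    return cg, prob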


def optimize_for_inference(
    output_vars,
    *,
    f16_io_f32_comp=False,
    f16_io_comp=False,
    use_nhwcd4=False,
    fuse_conv_bias_nonlinearity=False,
    use_nchw32=False,
    fuse_conv_bias_with_z=False,
    use_nchw4=False,
    use_nchw88=False,
    use_nchw44=False,
    use_nchw44_dot=False,
    use_chwn4=False
):
    """optimize computing graph for inference

    This applies a predefined set of optimization passes. Refer to the mnist
    sdk example and C++ code for fine-grained control.

    :param output_vars: output symvars
    :type output_vars: list of :class:`.SymbolVar`
    :param f16_io_f32_comp: whether to use float16 for I/O between oprs and use
        float32 as internal computation precision. Note the output var would be
        changed to float16
    :param f16_io_comp: whether to use float16 for both I/O and computation
        precision
    :param use_nhwcd4: whether to use NHWCD4 data format. This is faster on some
        OpenCL devices
    :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
        into one opr. This is supported only in NHWCD4 format.
    :param use_nchw4: whether to use NCHW4 tensor format.
    :param use_nchw88: whether to use NCHW88 tensor format. This may be faster
        in some cases.
    :param use_nchw44: whether to use NCHW44 tensor format. This may be faster
        in some cases.
    :param use_nchw44_dot: whether to use NCHW44_DOT tensor format. This format is
        optimized for inference on armv8.2.
    :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
        nvidia tensorcore.
    :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
        nvidia tensorcore.
    :return: list of transformed vars corresponding to given output vars
    """
    assert isinstance(output_vars, (list, tuple))
    opt = _detail._OptimizeForInferenceOptions()
    settings = locals()
    for i in [
        "f16_io_f32_comp",
        "f16_io_comp",
        "fuse_conv_bias_nonlinearity",
        "fuse_conv_bias_with_z",
    ]:
        if settings[i]:
            getattr(opt, "enable_{}".format(i))()

    layout_transform = None
    for k, v in {
        "use_nchw4": "nchw4",
        "use_nhwcd4": "nhwcd4",
        "use_nchw32": "nchw32",
        "use_nchw88": "nchw88",
        "use_nchw44": "nchw44",
        "use_nchw44_dot": "nchw44_dot",
        "use_chwn4": "chwn4",
    }.items():
        if settings[k]:
            assert (
                not layout_transform
            ), "Only one layout transform supported, both {} and {} given".format(
                layout_transform, k
            )
            getattr(opt, "enable_{}".format(v))()
            layout_transform = k

    vec = _detail._VectorSymbolVar()
    for i in output_vars:
        assert isinstance(i, _detail.SymbolVar), "bad var: {}".format(i)
        vec.push_back(i)
    return list(_detail._optimize_for_inference(vec, opt))
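

# --- illustrative usage sketch (editor's addition, not part of the original file) ---
# Typical inference-time rewrite: float16 I/O with float32 computation plus a
# single layout transform. At most one use_* layout flag may be enabled per call;
# whether this particular combination helps depends on the target device.
def _example_optimize(output_vars):
    return optimize_for_inference(
        output_vars,
        f16_io_f32_comp=True,
        use_nhwcd4=True,  # pick exactly one layout transform, or none
    )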


def get_opr_fp_graph_exec(comp_graph, output_vars):
    """get opr footprint and graph exec info

    This function will recompile the compute graph; any AsyncExecutable compiled
    before will become invalid.

    :param comp_graph: ComputingGraph
    :param output_vars: list of :class:`.SymbolVar`
    """
    assert isinstance(output_vars, (list, tuple))
    vec = _detail._VectorSymbolVar()
    for i in output_vars:
        assert isinstance(i, _detail.SymbolVar), "bad var: {}".format(i)
        vec.push_back(i)
    return json.loads(_detail._get_opr_fp_graph_exec(comp_graph, vec))


def to_mgb_supported_dtype(dtype_):
    """get the dtype supported by megbrain nearest to given dtype"""
    if (
        dtype.is_lowbit(dtype_)
        or dtype.is_quantize(dtype_)
        or dtype.is_bfloat16(dtype_)
    ):
        return dtype_
    return _detail._to_mgb_supported_dtype(dtype_)


def return_free_memory():
    """return free memory chunks on all devices.

    This function will try its best to return all consecutive free chunks back
    to the operating system; small pieces may not be returned.

    Please note that this function will not move any memory that is in use.
    """
    _detail.CompNode._try_coalesce_all_free_memory()

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on cloud GPU compute, you are welcome to visit the MegStudio platform.