
__init__.py 22 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""the megbrain python package

Note that all the submodules are automatically imported, so you usually only
need to ``import megengine._internal as mgb``.
"""

import collections
import json
import os
import sys

import numpy as np

from . import comp_graph_tools as cgtools
from . import config, craniotome, dtype
from . import global_init as _global_init
from . import helper as _helper
from . import mgb as _detail
from . import opr, opr_param_defs, plugin
from .exc import MegBrainError
from .logconf import get_logger
from .mgb import (
    CompGraph,
    CompNode,
    SharedND,
    SharedScalar,
    SymbolVar,
    TensorValueDumperContext,
    TensorValueLoaderContext,
)
from .mgb import as_comp_node as comp_node
from .mgb_helper import SharedNDLazyInitializer, callback_lazycopy, copy_output
from .plugin import CompGraphProfiler
from .plugin import GlobalInfkernFinder as _GlobalInfkernFinder
from .plugin import NumRangeChecker
from .version import __version__, version_info

if sys.version_info.major < 3:
    raise ImportError("megbrain requires python 3")


class ProxySharedNDAndSymbolVar(_detail.SymbolVar):
    """this is a :class:`.SymbolVar` with a corresponding :class:`.SharedND`.
    It can participate in graph computing and also provides :meth:`set_value`
    and :meth:`get_value`. It should be constructed by :func:`make_shared`.
    """

    __shared_nd = None
    __kwargs = None

    def __init__(self, snd, comp_graph, name, **kwargs):
        self.__shared_nd = snd
        self.__kwargs = kwargs
        self.this = snd.symvar(comp_graph=comp_graph, name=name, **kwargs).this

    def set_value(self, v, **kwargs):
        ret = self.__shared_nd.set_value(v, **kwargs)
        self._reeval_if_eager_eval()
        return ret

    def get_value(self):
        return self.__shared_nd.get_value()

    def reset_zero(self):
        self.__shared_nd.reset_zero()


def make_shared(
    comp_node,
    *,
    dtype=None,
    shape=None,
    value=None,
    comp_graph=None,
    name=None,
    volatile=None
):
    """make a shared tensor which is stored on device and could be modified
    later, either as a :class:`.SymbolVar` or a :class:`.SharedND` object

    :param comp_node: computing node
    :type comp_node: :class:`.CompNode`
    :param dtype: data type; if it is None, then dtype of value would be used
        if value is not None, and float32 would be used as default dtype if
        value is None
    :type dtype: :class:`numpy.dtype` compatible
    :param value: initializing value
    :type value: None or :class:`numpy.ndarray`
    :param comp_graph: the computing graph to which this shared value should
        belong; if provided, the returned object could be used as a
        :class:`.SymbolVar`
    :type comp_graph: None or :class:`.CompGraph`
    :param name: node name to be used in computing graph; only meaningful if
        *comp_graph* is provided
    :param volatile: if *comp_graph* is given then *volatile* indicates whether
        shape or mem ptr of this SharedND can be changed
    :rtype: :class:`.SharedND` if *comp_graph* is not given; or
        :class:`ProxySharedNDAndSymbolVar` otherwise
    """
    if dtype is None:
        if value is not None:
            value = np.ascontiguousarray(value)
            dtype = to_mgb_supported_dtype(value.dtype)
        else:
            dtype = np.float32
    comp_node = _detail.as_comp_node(comp_node)
    rst = _detail.SharedND(comp_node, dtype)
    if value is not None:
        assert shape is None, "could not provide both value and shape"
        rst.set_value(value)
    elif shape is not None:
        rst._set_init_shape(shape)
    if comp_graph is None:
        assert name is None and volatile is None
        return rst
    assert isinstance(comp_graph, CompGraph), "expect CompGraph but got {}".format(
        comp_graph
    )
    if volatile is None:
        volatile = False
    else:
        assert isinstance(volatile, bool)
    return ProxySharedNDAndSymbolVar(rst, comp_graph, name, volatile=volatile)
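
# Usage sketch (assuming a device locator such as "xpu0" is available):
#
#     cg = comp_graph()
#     w = make_shared("xpu0", value=np.zeros((4, 4), dtype="float32"),
#                     comp_graph=cg, name="w")
#     w.set_value(np.eye(4, dtype="float32"))  # update the device tensor later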


def make_immutable(comp_node, comp_graph, value, *, dtype=None, name=None):
    """make a graph node containing an immutable tensor from host tensor value

    :param dtype: required data type; if not None, the data would be converted
        to that type; otherwise the dtype of *value* would be used
    """
    comp_node = _detail.as_comp_node(comp_node)
    assert isinstance(
        comp_graph, _detail.CompGraph
    ), "expect CompGraph but got {!r}".format(comp_graph)
    config = _detail.make_opr_config(name, comp_node)
    return _helper.cvt_opr_result(
        _detail._make_immutable(comp_graph, value, dtype, config)
    )


def make_arg(
    comp_node,
    comp_graph,
    *,
    dtype=np.float32,
    shape=None,
    name=None,
    value=None,
    enable_static_infer=True
):
    """make an argument to be passed to compiled function during runtime

    :type shape: None or tuple of int
    :param shape: expected tensor shape to be used for shape inferring; actual
        tensor shape could be different
    :type name: str
    :param name: name of the generated var node
    :type value: None or ndarray-compatible
    :param value: initial value used for static inference; if not given, static
        infer would be deferred to first graph execution
    :param enable_static_infer: whether to enable static inference for this var
    """
    host_val = _detail._HostSharedND(comp_node, dtype)

    if value is not None:
        value = np.ascontiguousarray(value, dtype=dtype)
        if shape is None:
            shape = value.shape
        else:
            assert shape == value.shape
    if shape is not None:
        host_val._resize(shape)

    if value is not None:
        host_val.set_value(value)

    return _helper.cvt_opr_result(
        ProxySharedNDAndSymbolVar(
            host_val, comp_graph, name, enable_static_infer=enable_static_infer
        )
    )
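
# Usage sketch: declare a runtime argument with a shape hint for static
# inference (device locator and names are illustrative):
#
#     cg = comp_graph()
#     x = make_arg("xpu0", cg, dtype=np.float32, shape=(1, 3, 224, 224), name="x")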


def comp_graph(*, extra_opts=None, check_env_var=True):
    """allocate a new computing graph

    :param extra_opts: extra options to be set; would be updated (modified
        inplace) from ``MGB_COMP_GRAPH_OPT`` environment var. See
        :func:`.set_comp_graph_option` for list of supported options.
    :type extra_opts: dict
    :param check_env_var: whether to check environment vars
    :type check_env_var: bool
    :return: the comp graph object
    :rtype: :class:`.CompGraph`
    """
    cg = _detail.CompGraph()
    if extra_opts is None:
        extra_opts = {}
    if check_env_var:
        setting = os.getenv("MGB_COMP_GRAPH_OPT")
        if setting:
            for item in setting.split(";"):
                k, v = item.split("=", 1)
                extra_opts.setdefault(k, v)
            get_logger().warning(
                "set comp graph option from env: {}".format(extra_opts)
            )
        user_data = os.getenv("MGB_COMP_GRAPH_USER_DATA")
        if user_data:
            storage = cg.user_data
            for ud in user_data.split(";"):
                k, v = ud.split("=", 1)
                storage[k] = eval(v)
        _GlobalInfkernFinder.add_graph(cg)
    for k, v in extra_opts.items():
        cg.set_option(k, v)
    return cg
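
# Usage sketch: the same options can also be injected via the environment,
# e.g. ``MGB_COMP_GRAPH_OPT="graph_opt_level=2"`` (the option name here is
# illustrative; see set_comp_graph_option for the supported keys):
#
#     cg = comp_graph(extra_opts={"graph_opt_level": 2})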


def grad(
    target, wrt, warn_mid_wrt=True, use_virtual_grad=None, return_zero_for_nodep=True
):
    r"""compute symbolic grad

    :param target: grad target var
    :type target: :class:`.SymbolVar`
    :param wrt: with respect to which to compute the grad
    :type wrt: :class:`.SymbolVar` or Iterable[SymbolVar]
    :param warn_mid_wrt: whether to give warning if *wrt* is not endpoint
    :type warn_mid_wrt: bool
    :param use_virtual_grad: whether to use virtual grad opr, so fwd graph can
        be optimized before applying grad; if ``None`` is given, then virtual
        grad would be used if ``graph_opt_level >= 2``
    :type use_virtual_grad: :class:`bool` or ``None``
    :param return_zero_for_nodep: if *target* does not depend on *wrt*, set to
        True to return a zero-valued :class:`.SymbolVar` rather than ``None``;
        can't be set to False when using virtual grad opr.
    :type return_zero_for_nodep: bool
    :rtype: :class:`.SymbolVar` or None
    :return: :math:`\frac{\partial\text{target}}{\partial\text{wrt}}`
    """
    if use_virtual_grad is None:
        use_virtual_grad = -1
    else:
        use_virtual_grad = 1 if use_virtual_grad else 0
    if isinstance(wrt, SymbolVar):
        wrts = [
            wrt,
        ]
    else:
        wrts = wrt
    assert isinstance(wrts, collections.Iterable)

    # return an invalid SymbolVar (with nullptr VarNode*) when
    # return_zero_for_nodep is False and target doesn't depend on wrt
    grads = _detail._grad(
        target, wrts, bool(warn_mid_wrt), use_virtual_grad, return_zero_for_nodep
    )
    grads = list(grads)

    for i in range(len(grads)):
        if not grads[i].valid:
            assert (
                not return_zero_for_nodep
            ), "invalid grad SymbolVar: target={}, wrt={}".format(target, wrts[i])
            grads[i] = None

    if len(grads) == 1:
        grads = grads[0]

    return grads
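
# Usage sketch, with ``loss``, ``w`` and ``b`` being SymbolVars of the same
# graph (names are illustrative):
#
#     dw = grad(loss, w)            # single wrt -> a single SymbolVar
#     dw, db = grad(loss, [w, b])   # iterable wrt -> a list of SymbolVars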


def current_grad_target(comp_graph):
    """get current target var to compute grad, used for implementing custom
    gradient"""
    return _detail._current_grad_target(comp_graph)


def inter_graph_trans_var(dest_graph, src):
    """get the corresponding var of *src* in *dest_graph*; assuming
    *dest_graph* is a copy of owner graph of *src*; usually used in callback of
    set_grad to get grad of vars in loop

    :param dest_graph: target computing graph
    :type dest_graph: :class:`.CompGraph`
    :param src: source var node
    :type src: :class:`.SymbolVar`
    :return: corresponding var in *dest_graph*
    :rtype: :class:`.SymbolVar`
    """
    return _detail._inter_graph_trans_var(dest_graph, src)


def get_graph_optimizer_replaced_var(src):
    """get optimized var corresponding to given var; usually used in callback
    of set_grad to get grad w.r.t. some var

    :param src: source var node
    :type src: :class:`.SymbolVar`
    :rtype: :class:`.SymbolVar`
    """
    return _detail._get_graph_optimizer_replaced_var(src)


CompGraphSerializationResult = collections.namedtuple(
    "CompGraphSerializationResult",
    [
        "nr_opr",
        "tot_bytes",
        "tensor_value_bytes",
        "content_hash",
        "inputs",
        "outputs",
        "params",
    ],
)


def serialize_comp_graph_to_file(
    fpath,
    output_vars,
    *,
    keep_var_name=1,
    keep_param_name=False,
    keep_opr_priority=False,
    tensor_value_dumper=None,
    output_strip_info=False,
    append=False,
    format=None,
    **kwargs
):
    """serialize this computing graph and write result to a file. Note:
    ``kwargs`` exists for backward compatibility; there are no additional
    arguments.

    :param fpath: path for the output file
    :type fpath: ``str``
    :param output_vars: output variables that need to be retrieved when
        deserializing

        .. note::

            The underlying C++ API only accepts a var list. If a dict is given,
            the vars would be renamed to given names.

    :type output_vars: dict(name => :class:`.SymbolVar`), or a list of vars
    :param keep_var_name: level for keeping variable names:

        * 0: none of the names are kept
        * 1: keep names of output vars
        * 2: keep names of all (output and internal) vars

    :param keep_param_name: whether to keep param names, so param values can be
        easily manipulated after loading model
    :param keep_opr_priority: whether to keep priority setting for operators
    :param tensor_value_dumper: a callable to dump tensor values; it should
        only write the tensor value without layout information. It would be
        given a :class:`.TensorValueDumperContext` object as its sole argument.
    :param output_strip_info: if set to True, then a json file containing
        information for code strip would be written to ``fpath+'.json'``
    :param append: whether to open output file in append mode
    :param format: serialization format of the resulting model, should be either
        "mdl" or "fbs"; None means default.
    :type format: ``str``
    :return: an instance of namedtuple :class:`CompGraphSerializationResult`,
        whose fields are:

        * ``nr_opr`` number of operators dumped
        * ``tot_bytes`` total bytes for the whole graph
        * ``tensor_value_bytes`` bytes consumed for dumping tensor values
        * ``inputs`` names of input tensors
        * ``params`` list of names of dumped params
        * ``outputs`` names of output vars
    """
    assert isinstance(fpath, str), "bad file path: {!r}".format(fpath)
    ov = _detail._VectorSymbolVar()
    SUPPORTED_FORMATS = {
        # default
        None: _detail.GraphDumpFormat_FLATBUFFERS,
        "fbs": _detail.GraphDumpFormat_FLATBUFFERS,
    }
    resolved_fmt = SUPPORTED_FORMATS.get(format, None)
    if resolved_fmt is None:
        raise ValueError(
            "unknown format {} requested, supported ones are {}".format(
                format, list(filter(None, SUPPORTED_FORMATS.keys()))
            )
        )
    if isinstance(output_vars, dict):
        used_vars = set()
        for name, var in output_vars.items():
            assert isinstance(var, _detail.SymbolVar), "bad output var: {!r}".format(
                var
            )
            assert var.id not in used_vars, (
                "var name is associated with a var object, so we can not have "
                "two names given to the same var: {}".format(var)
            )
            used_vars.add(var.id)
            var.rename(name)
            ov.push_back(var)
    else:
        for i in output_vars:
            assert isinstance(i, _detail.SymbolVar), "bad output var: {!r}".format(i)
            ov.push_back(i)
    if tensor_value_dumper is not None:
        assert isinstance(tensor_value_dumper, collections.Callable)

        class Callback(_detail._TensorValueDumperCallback):
            def call(self, ctx, *, _f=tensor_value_dumper):
                _f(ctx)

        tensor_value_dumper = Callback()

    # for backward compatibility
    mangle_opr_name = kwargs.pop("mangle_opr_name", ov)
    if mangle_opr_name is not ov:
        get_logger().warning("mangle_opr_name is deprecated; use keep_var_name instead")
        keep_var_name = 1 if mangle_opr_name else 2
    mangle_param_name = kwargs.pop("mangle_param_name", ov)
    assert (
        not kwargs
    ), "extra kwargs provided to serialize_comp_graph_to_file: {}".format(kwargs)

    if mangle_param_name is not ov:
        get_logger().warning(
            "mangle_param_name is deprecated; use keep_param_name instead"
        )
        keep_param_name = not mangle_param_name

    inputs = _detail._VectorString()
    outputs = _detail._VectorString()
    params = _detail._VectorString()
    stat = _detail._VectorSizeT()

    _detail._serialize_comp_graph_to_file(
        fpath,
        append,
        resolved_fmt,
        ov,
        keep_var_name,
        keep_param_name,
        keep_opr_priority,
        tensor_value_dumper,
        stat,
        inputs,
        outputs,
        params,
    )

    dump_ret = CompGraphSerializationResult(
        *stat, list(inputs), list(outputs), list(params)
    )

    if output_strip_info:
        with open(fpath + ".json", "w") as fout:
            strip_info = _detail._get_info_for_strip(ov)
            strip_info_dict = json.loads(strip_info)
            strip_info_dict["hash"] = dump_ret.content_hash
            json.dump(strip_info_dict, fout)

    return dump_ret
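
# Usage sketch, with ``y`` being an output SymbolVar (file name illustrative):
#
#     ret = serialize_comp_graph_to_file(
#         "model.mge", {"pred": y}, keep_var_name=2, output_strip_info=True
#     )
#     get_logger().info("dumped %d oprs, %d bytes", ret.nr_opr, ret.tot_bytes)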


CompGraphLoadResult = collections.namedtuple(
    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
)


def load_comp_graph_from_file(
    fpath, *, comp_node_mapper=None, tensor_value_loader=None
):
    """Load a serialized computing graph from file.

    :param fpath: Path for the input file.
    :type fpath: ``str``
    :param comp_node_mapper: A callable to modify comp node locator, takes old
        locator as argument and returns new locator.
    :type comp_node_mapper: Callable[[str], str]
    :param tensor_value_loader: A callable to load tensor values. It should
        read the tensor value with the given shape and dtype and return it as
        NumPy ndarray. It would be given a :class:`.TensorValueLoaderContext`
        object as its sole argument.
    :type tensor_value_loader: Callable[[TensorValueLoaderContext], numpy.ndarray]
    :return: An instance of namedtuple :class:`CompGraphLoadResult`,
        whose fields are:

        * ``graph`` loaded CompGraph
        * ``output_vars_dict`` A Python dict, mapping name to output SymbolVar
        * ``output_vars_list`` A Python list, containing output vars in the
          order passed to serialize_comp_graph_to_file
    """
    assert isinstance(fpath, str), "bad file path: {!r}".format(fpath)

    if comp_node_mapper is not None:
        assert isinstance(comp_node_mapper, collections.Callable)

        class Callback(_detail._CompNodeMapperCallback):
            def call(self, desc, *, _f=comp_node_mapper):
                return _f(desc)

        comp_node_mapper = Callback()
    if tensor_value_loader is not None:
        assert isinstance(tensor_value_loader, collections.Callable)

        class Callback(_detail._TensorValueLoaderCallback):
            def call(self, ctx, *, _f=tensor_value_loader):
                return _f(ctx)

        tensor_value_loader = Callback()
    output_vars_map = _detail._VectorPairStringSymbolVar()
    output_vars_list = _detail._VectorSymbolVar()
    cg = _detail._load_comp_graph_from_file(
        fpath, comp_node_mapper, tensor_value_loader, output_vars_map, output_vars_list
    )
    return CompGraphLoadResult(cg, dict(list(output_vars_map)), list(output_vars_list))
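
# Usage sketch: load a dumped model and remap every comp node to CPU (the
# returned locator string is illustrative):
#
#     ret = load_comp_graph_from_file("model.mge",
#                                     comp_node_mapper=lambda loc: "cpu0")
#     cg, outputs = ret.graph, ret.output_vars_list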


def optimize_for_inference(
    output_vars,
    *,
    f16_io_f32_comp=False,
    f16_io_comp=False,
    use_nhwcd4=False,
    fuse_conv_bias_nonlinearity=False,
    use_nchw32=False,
    fuse_conv_bias_with_z=False,
    use_nchw88=False,
    use_nchw44=False,
    use_chwn4=False
):
    """optimize computing graph for inference

    This applies a predefined set of optimization passes. Refer to the mnist
    sdk example and C++ code for fine-grained control.

    :param output_vars: output symvars
    :type output_vars: list of :class:`.SymbolVar`
    :param f16_io_f32_comp: whether to use float16 for I/O between oprs and use
        float32 as internal computation precision. Note the output var would be
        changed to float16
    :param f16_io_comp: whether to use float16 for both I/O and computation
        precision
    :param use_nhwcd4: whether to use NHWCD4 data format. This is faster on some
        OpenCL devices
    :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
        into one opr. This is supported only in NHWCD4 format.
    :param use_nchw88: whether to use NCHW88 tensor format. This may be faster
        in some cases.
    :param use_nchw44: whether to use NCHW44 tensor format. This may be faster
        in some cases.
    :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
        nvidia tensorcore.
    :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
        nvidia tensorcore.
    :return: list of transformed vars corresponding to given output vars
    """
    assert isinstance(output_vars, (list, tuple))
    opt = _detail._OptimizeForInferenceOptions()
    settings = locals()
    for i in [
        "f16_io_f32_comp",
        "f16_io_comp",
        "fuse_conv_bias_nonlinearity",
        "fuse_conv_bias_with_z",
    ]:
        if settings[i]:
            getattr(opt, "enable_{}".format(i))()

    layout_transform = None
    for k, v in {
        "use_nhwcd4": "nhwcd4",
        "use_nchw32": "nchw32",
        "use_nchw88": "nchw88",
        "use_nchw44": "nchw44",
        "use_chwn4": "chwn4",
    }.items():
        if settings[k]:
            assert (
                not layout_transform
            ), "only one layout transform is supported, but both {} and {} are given".format(
                layout_transform, k
            )
            getattr(opt, "enable_{}".format(v))()
            layout_transform = k

    vec = _detail._VectorSymbolVar()
    for i in output_vars:
        assert isinstance(i, _detail.SymbolVar), "bad var: {}".format(i)
        vec.push_back(i)
    return list(_detail._optimize_for_inference(vec, opt))
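
# Usage sketch, with ``y`` being an output SymbolVar: fuse conv+bias+nonlinearity
# and switch to the NHWCD4 layout before dumping:
#
#     (y_opt,) = optimize_for_inference(
#         [y], fuse_conv_bias_nonlinearity=True, use_nhwcd4=True
#     )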


def get_opr_fp_graph_exec(comp_graph, output_vars):
    """get opr footprint and graph exec info

    This function will recompile the computing graph, so any AsyncExecutable
    compiled before will become invalid.

    :param comp_graph: ComputingGraph
    :param output_vars: list of :class:`.SymbolVar`
    """
    assert isinstance(output_vars, (list, tuple))
    vec = _detail._VectorSymbolVar()
    for i in output_vars:
        assert isinstance(i, _detail.SymbolVar), "bad var: {}".format(i)
        vec.push_back(i)
    return json.loads(_detail._get_opr_fp_graph_exec(comp_graph, vec))


def to_mgb_supported_dtype(dtype_):
    """get the dtype supported by megbrain nearest to given dtype"""
    if (
        dtype.is_lowbit(dtype_)
        or dtype.is_quantize(dtype_)
        or dtype.is_bfloat16(dtype_)
    ):
        return dtype_
    return _detail._to_mgb_supported_dtype(dtype_)

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU release. To run GPU programs, make sure the machine has a GPU device and the driver is properly installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.