
module.py 24 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np

from ..core.tensor.utils import make_shape_tuple
from ..logger import get_logger
from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated
from ..utils.hook import HookHandler
from ..utils.naming import auto_naming

logger = get_logger(__name__)


def _expand_structure(prefix, obj):
    if isinstance(obj, (Tensor, Module)):
        return [(prefix, obj)]
    elif isinstance(obj, (list, tuple, dict)):
        ret = []
        if isinstance(obj, dict):
            targets = ((k, obj[k]) for k in sorted(obj))
        else:
            targets = ((str(k), v) for k, v in enumerate(obj))
        for k, o in targets:
            sub_ret = _expand_structure(k, o)
            if sub_ret and not isinstance(k, str):
                raise AssertionError(
                    "keys for Tensor and Module must be str, error key: {}".format(k)
                )
            for kt, vt in sub_ret:
                ret.extend([(prefix + "." + kt, vt)])
        return ret
    else:
        return []


def _access_structure(obj, key, callback=None):
    key_list = key.split(".")
    cur = obj
    parent = None
    for k in key_list:
        parent = cur
        if isinstance(cur, (Tensor, Module)):
            cur = getattr(cur, k)
        elif isinstance(cur, (list, tuple)):
            k = int(k)
            cur = cur[k]
        elif isinstance(cur, dict):
            cur = cur[k]
        else:
            raise ValueError(
                "Unsupported value type {} to access attribute".format(type(cur))
            )
    return callback(parent, k, cur)


def _is_parameter(obj):
    return isinstance(obj, Parameter)


def _is_buffer(obj):
    return isinstance(obj, Tensor) and not isinstance(obj, Parameter)


def _is_module(obj):
    return isinstance(obj, Module)


def _get_XNorm_typeclass():
    from .batchnorm import _BatchNorm
    from .normalization import GroupNorm, InstanceNorm, LayerNorm

    XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
    return XNorm_types
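

# Editor's note: a hedged illustration (not part of the upstream file) of how
# ``_expand_structure`` flattens nested containers into dotted keys; ``conv0``,
# ``conv1`` and ``bn`` are hypothetical Tensor/Module attributes:
#
#     _expand_structure("block", {"convs": [conv0, conv1], "bn": bn})
#     # -> [("block.bn", bn), ("block.convs.0", conv0), ("block.convs.1", conv1)]
#
# Dict keys are visited in sorted order and list/tuple items by index, which is
# what makes the traversal order of ``Module._flatten`` deterministic.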


class Module(metaclass=ABCMeta):
    """
    Base Module class.
    """

    def __init__(self, name=None):
        """
        :param name: module's name, can be initialized by the ``kwargs`` parameter
            of child class.
        """
        self._modules = []

        if name is not None:
            assert (
                isinstance(name, str) and name.strip()
            ), "Module's name must be a non-empty string"
        self.name = name

        # runtime attributes
        self.training = True
        self.quantize_disabled = False

        # hooks
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

        # used for profiler and automatic naming
        self._name = "{anonymous}"

    @abstractmethod
    def forward(self, inputs):
        pass

    def register_forward_pre_hook(self, hook: Callable) -> HookHandler:
        """
        Registers a hook to handle forward inputs. `hook` should be a function.

        :param hook: a function that receives `module` and `inputs`, then returns
            a modified `inputs` or `None`.
        :return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
        """
        return HookHandler(self._forward_pre_hooks, hook)

    def register_forward_hook(self, hook: Callable) -> HookHandler:
        """
        Registers a hook to handle forward results. `hook` should be a function that
        receives `module`, `inputs` and `outputs`, then returns a modified `outputs` or `None`.

        This method returns a handler with a :meth:`~.HookHandler.remove` interface to delete the hook.
        """
        return HookHandler(self._forward_hooks, hook)
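
    # Editor's note: a hedged usage sketch (not part of the upstream file);
    # ``net`` and ``data`` are hypothetical. A hook returning ``None`` leaves
    # the outputs unchanged:
    #
    #     def print_call(module, inputs, outputs):
    #         print(module._name, len(inputs))
    #
    #     handler = net.register_forward_hook(print_call)
    #     net(data)         # the hook fires after ``forward``
    #     handler.remove()  # detach the hook again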

    def __call__(self, *inputs, **kwargs):
        auto_naming.push_scope(self.name if self.name is not None else self._name)
        for hook in self._forward_pre_hooks.values():
            modified_inputs = hook(self, inputs)
            if modified_inputs is not None:
                if not isinstance(modified_inputs, tuple):
                    modified_inputs = (modified_inputs,)
                inputs = modified_inputs
        outputs = self.forward(*inputs, **kwargs)
        for hook in self._forward_hooks.values():
            modified_outputs = hook(self, inputs, outputs)
            if modified_outputs is not None:
                outputs = modified_outputs
        auto_naming.pop_scope()
        return outputs

    def _flatten(
        self,
        *,
        recursive: bool = True,
        with_key: bool = False,
        with_parent: bool = False,
        prefix: Optional[str] = None,
        predicate: Callable[[Any], bool] = lambda _: True,
        seen: Optional[Set[int]] = None
    ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
        """
        Scans the module object and returns an iterable for the :class:`~.Tensor`
        and :class:`~.Module` attributes that agree with the ``predicate``. For multiple
        calls of this function with the same arguments, the order of objects within the
        returned iterable is guaranteed to be identical, as long as all the involved
        module objects' ``__dict__`` do not change throughout those calls.

        :param recursive: whether to recursively scan all the submodules.
        :param with_key: whether to yield keys along with yielded objects.
        :param with_parent: whether to yield ``self`` along with yielded objects.
        :param prefix: prefix appended to the yielded keys.
        :param predicate: the predication function applied to scanned objects.
        :param seen: a set that records the ids of modules which have been traversed.
        """
        if seen is None:
            seen = set([id(self)])

        module_dict = vars(self)
        _prefix = "" if prefix is None else prefix + "."

        for key in sorted(module_dict):
            for expanded_key, leaf in _expand_structure(key, module_dict[key]):
                leaf_id = id(leaf)
                if leaf_id in seen:
                    continue
                seen.add(leaf_id)

                if predicate(leaf):
                    if with_key and with_parent:
                        yield _prefix + expanded_key, leaf, self
                    elif with_key:
                        yield _prefix + expanded_key, leaf
                    elif with_parent:
                        yield leaf, self
                    else:
                        yield leaf

                if recursive and isinstance(leaf, Module):
                    yield from leaf._flatten(
                        recursive=recursive,
                        with_key=with_key,
                        with_parent=with_parent,
                        prefix=_prefix + expanded_key if with_key else None,
                        predicate=predicate,
                        seen=seen,
                    )

    def parameters(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]:
        r"""
        Returns an iterable for the :class:`~.Parameter` of the module.

        :param recursive: If ``True``, returns all :class:`~.Parameter` within this
            module, else only returns :class:`~.Parameter` that are direct attributes
            of this module.
        """
        if "requires_grad" in kwargs:
            del kwargs["requires_grad"]
            logger.warning(
                "Tensor currently has no requires_grad attribute "
                "so requires_grad argument is ignored here"
            )

        def predicate(obj) -> bool:
            return _is_parameter(obj)

        yield from self._flatten(
            with_key=False, predicate=predicate, recursive=recursive, **kwargs
        )

    def named_parameters(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Parameter]]:
        """
        Returns an iterable for key :class:`~.Parameter` pairs of the module, where
        ``key`` is the dotted path from this module to the :class:`~.Parameter`.

        :param prefix: prefix prepended to the keys.
        :param recursive: if ``True``, returns all :class:`~.Parameter` within this
            module, else only returns :class:`~.Parameter` that are direct attributes
            of this module.
        """
        if "requires_grad" in kwargs:
            del kwargs["requires_grad"]
            logger.warning(
                "Tensor currently has no requires_grad attribute "
                "so requires_grad argument is ignored here"
            )

        def predicate(obj) -> bool:
            return _is_parameter(obj)

        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=predicate,
            recursive=recursive,
            **kwargs,
        )

    def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Tensor]:
        """
        Returns an iterable for the buffers of the module.

        Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.

        :param recursive: if ``True``, returns all buffers within this
            module, else only returns buffers that are direct attributes
            of this module.
        """
        yield from self._flatten(
            with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
        )

    def named_buffers(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Tensor]]:
        """
        Returns an iterable for key buffer pairs of the module, where
        ``key`` is the dotted path from this module to the buffer.

        Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.

        :param prefix: prefix prepended to the keys.
        :param recursive: if ``True``, returns all buffers within this
            module, else only returns buffers that are direct attributes
            of this module.
        """
        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=_is_buffer,
            recursive=recursive,
            **kwargs,
        )
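
    # Editor's note: a hedged usage sketch (not part of the upstream file);
    # ``net`` is a hypothetical Module. Keys are the dotted attribute paths
    # produced by ``_flatten``/``_expand_structure``:
    #
    #     for key, param in net.named_parameters():
    #         print(key, param.shape)   # e.g. "layers.0.weight", (8, 4)
    #     n_bufs = len(list(net.buffers(recursive=True)))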

    def children(self, **kwargs) -> "Iterable[Module]":
        """
        Returns an iterable for all the submodules that are direct attributes of this
        module.
        """
        yield from self._flatten(
            with_key=False, predicate=_is_module, recursive=False, **kwargs
        )

    def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]":
        """
        Returns an iterable of key-submodule pairs for all the submodules that are
        direct attributes of this module, where 'key' is the attribute name of
        submodules.
        """
        yield from self._flatten(
            with_key=True, predicate=_is_module, recursive=False, **kwargs
        )

    def modules(self, **kwargs) -> "Iterable[Module]":
        """
        Returns an iterable for all the modules within this module, including itself.
        """
        if "with_parent" in kwargs and kwargs["with_parent"]:
            yield self, None
        else:
            yield self
        yield from self._flatten(with_key=False, predicate=_is_module, **kwargs)

    def named_modules(
        self, prefix: Optional[str] = None, **kwargs
    ) -> "Iterable[Tuple[str, Module]]":
        """
        Returns an iterable of key-module pairs for all the modules within this
        module, including itself, where 'key' is the dotted path from this module to the
        submodules.

        :param prefix: prefix prepended to the path.
        """
        if "with_parent" in kwargs and kwargs["with_parent"]:
            yield ("" if prefix is None else prefix), self, None
        else:
            yield ("" if prefix is None else prefix), self
        yield from self._flatten(
            with_key=True, prefix=prefix, predicate=_is_module, **kwargs
        )

    def apply(self, fn: "Callable[[Module], Any]") -> None:
        """
        Applies function ``fn`` to all the modules within this module, including
        itself.

        :param fn: the function to be applied on modules.
        """
        for it in self.modules():
            fn(it)
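
    # Editor's note: a hedged usage sketch (not part of the upstream file),
    # following the ``M.init`` convention used in the ``load_state_dict``
    # docstring below; ``net`` is a hypothetical Module:
    #
    #     def zero_biases(m):
    #         if getattr(m, "bias", None) is not None:
    #             M.init.zero_(m.bias)
    #
    #     net.apply(zero_biases)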

    @deprecated(version="1.0")
    def zero_grad(self) -> None:
        """
        Sets all parameters' grads to zero.
        """
        for param in self.parameters():
            if param.grad is not None:
                param.grad.reset_zero()

    def train(self, mode: bool = True, recursive: bool = True) -> None:
        """
        Sets training mode of all the modules within this module (including itself) to
        ``mode``. This effectively sets the ``training`` attributes of those modules
        to ``mode``, but only has effect on certain modules (e.g.
        :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`).

        :param mode: the training mode to be set on modules.
        :param recursive: whether to recursively call submodules' ``train()``.
        """
        if not recursive:
            self.training = mode
            return

        def fn(module: Module) -> None:
            module.train(mode, recursive=False)

        self.apply(fn)

    def eval(self) -> None:
        """
        Sets training mode of all the modules within this module (including itself) to
        ``False``. See :meth:`~.Module.train` for details.
        """
        self.train(False)
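
    # Editor's note: a hedged usage sketch (not part of the upstream file);
    # ``net`` is hypothetical:
    #
    #     net.train()   # enable training behavior in e.g. Dropout/BatchNorm
    #     net.eval()    # switch the whole module tree to inference behavior
    #     assert not net.training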

    def disable_quantize(self, value=True):
        r"""
        Sets ``module``'s ``quantize_disabled`` attribute and returns ``module``.
        Could be used as a decorator.
        """

        def fn(module: Module) -> None:
            module.quantize_disabled = value

        self.apply(fn)

    @deprecated(version="1.0")
    def replace_param(
        self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
    ):
        """
        Replaces module's parameters with ``params``, used by :class:`~.ParamPack` to
        speed up multi-machine training.
        """
        offset = 0
        if seen is None:
            seen = set([id(self)])
        module_dict = vars(self)
        for key in sorted(module_dict):
            hash_id = id(module_dict[key])
            if hash_id in seen:
                continue
            seen.add(hash_id)
            if isinstance(module_dict[key], Parameter):
                if start_pos + offset in params:
                    assert make_shape_tuple(module_dict[key].shape) == make_shape_tuple(
                        params[start_pos + offset].shape
                    )
                    module_dict[key] = params[start_pos + offset]
                offset += 1
            if isinstance(module_dict[key], Module):
                offset += module_dict[key].replace_param(
                    params, start_pos + offset, seen
                )
        return offset

    def state_dict(self, rst=None, prefix="", keep_var=False):
        _rst = self._state_dict(rst=rst, prefix=prefix, keep_var=keep_var)
        rst = OrderedDict()
        XNorm_typeclass = _get_XNorm_typeclass()
        for (module_type, k), v in _rst.items():
            # for performance reasons, parameters in XNorm (e.g., BatchNorm2d) are 4-dim tensors;
            # however, they will be reshaped to 1-dim tensors before being returned by `state_dict()`
            if issubclass(module_type, XNorm_typeclass):
                v = v.reshape(-1)
            rst[k] = v
        return rst

    def _state_dict(self, rst=None, prefix="", keep_var=False):
        r"""
        Returns a dictionary containing whole states of the module.
        """

        def is_state(obj):
            return _is_parameter(obj) or _is_buffer(obj)

        module_type = self.__class__
        if rst is None:
            rst = OrderedDict()

        for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
            assert prefix + k not in rst, "duplicated state: {}".format(k)
            if keep_var:
                rst[(module_type, prefix + k)] = v
            else:
                rst[(module_type, prefix + k)] = v.numpy()

        for k, submodule in self._flatten(
            recursive=False,
            with_key=True,
            predicate=lambda obj: isinstance(obj, Module),
        ):
            submodule.state_dict(rst, prefix + k + ".", keep_var)

        return rst

    def load_state_dict(
        self,
        state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]],
        strict=True,
    ):
        r"""
        Loads a given dictionary created by :func:`state_dict` into this module.
        If ``strict`` is ``True``, the keys of the given ``state_dict`` must exactly match
        the keys returned by :func:`state_dict`.

        Users can also pass a closure: ``Function[key: str, var: Tensor] -> Optional[np.ndarray]``
        as a ``state_dict``, in order to handle complex situations. For example, load everything
        except for the final linear classifier:

        .. code-block::

            state_dict = {...}  # Dict[str, np.ndarray]
            model.load_state_dict({
                k: None if k.startswith('fc') else v
                for k, v in state_dict.items()
            }, strict=False)

        Here returning ``None`` means skipping parameter ``k``.

        To prevent shape mismatch (e.g. when loading PyTorch weights), we can reshape before loading:

        .. code-block::

            state_dict = {...}
            def reshape_accordingly(k, v):
                return state_dict[k].reshape(v.shape)
            model.load_state_dict(reshape_accordingly)

        We can also perform inplace re-initialization or pruning:

        .. code-block::

            def reinit_and_pruning(k, v):
                if 'bias' in k:
                    M.init.zero_(v)
                if 'conv' in k:
                    return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32")
            model.load_state_dict(reinit_and_pruning, strict=False)
        """
        unused = []
        if isinstance(state_dict, dict):
            unused = state_dict.keys()

            def closure(k, _):  # var unused
                return state_dict[k] if k in state_dict else None

        elif callable(state_dict):
            closure = state_dict
        else:
            raise ValueError(
                "`state_dict` must be a dict or callable, got {}".format(
                    type(state_dict)
                )
            )

        loaded, skipped = self._load_state_dict_with_closure(closure)
        unused = set(unused) - loaded

        if len(unused) != 0:
            if strict:
                raise KeyError(
                    "Unused params violate `strict=True`, unused={}".format(unused)
                )
            else:
                logger.warning(
                    "Unused params in `strict=False` mode, unused={}".format(unused)
                )

        if len(skipped) != 0:
            if strict:
                raise KeyError(
                    "Missing params violate `strict=True`, missing={}".format(skipped)
                )
            else:
                logger.warning(
                    "Missing params in `strict=False` mode, missing={}".format(skipped)
                )

    def _load_state_dict_with_closure(self, closure):
        """
        Advanced state_dict loading through callable ``closure`` whose signature is
        ``closure(key: str, var: Tensor) -> Union[np.ndarray, None]``.
        """
        XNorm_typeclass = _get_XNorm_typeclass()
        assert callable(closure), "closure must be a function"

        loaded = []
        skipped = []

        local_state_dict = self._state_dict(keep_var=True)
        for (module_type, k), var in local_state_dict.items():
            to_be_load = closure(k, var)
            if to_be_load is None:
                skipped.append(k)
                continue
            assert isinstance(
                to_be_load, np.ndarray
            ), "closure should return a `np.ndarray`, but `{}` got {}".format(
                k, to_be_load
            )
            var_shape = make_shape_tuple(var.shape)
            to_be_load_shape = make_shape_tuple(to_be_load.shape)
            if var_shape != to_be_load_shape:
                # weight and bias in BatchNorm1d, BatchNorm2d and SyncBatchNorm are 1-dim tensors in v1.0,
                # and since v1.1 they are 4-dim tensors. The following special rule for these modules
                # preserves backward compatibility.
                if issubclass(module_type, XNorm_typeclass):
                    if np.prod(var_shape) == np.prod(to_be_load_shape):
                        to_be_load = to_be_load.reshape(var_shape)
                    else:
                        raise ValueError(
                            "param `{}` size mismatch, should be {}, got {}".format(
                                k, np.prod(var_shape), np.prod(to_be_load_shape)
                            )
                        )
                else:
                    raise ValueError(
                        "param `{}` shape mismatch, should be {}, got {}".format(
                            k, var_shape, to_be_load_shape
                        )
                    )
            var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
            loaded.append(k)

        return set(loaded), set(skipped)

    def __getattribute__(self, name: str):
        value = super().__getattribute__(name)
        if name == "__dict__":
            return value
        for prefix, variable in _expand_structure(name, value):
            variable._name = prefix
        return value

    def __setattr__(self, name: str, value):
        if _is_module(value) or (
            isinstance(value, (list, tuple, dict)) and name != "_modules"
        ):
            modules = self.__dict__.get("_modules")
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call"
                )
            if name not in self.__dict__:
                modules.append(name)
        super().__setattr__(name, value)

    def __delattr__(self, name: str):
        if name in self.__dict__ and _is_module(self.__dict__[name]):
            modules = self.__dict__.get("_modules")
            modules.remove(name)
        super().__delattr__(name)

    def _module_info_string(self) -> str:
        r"""
        Set the extra representation of the module.
        """
        return ""

    def __repr__(self):
        def add_indent(repr_str, num_spaces):
            s = repr_str.split("\n")
            # don't do anything for single-line stuff
            if len(s) == 1:
                return repr_str
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        extra_lines = []
        extra_repr = self._module_info_string()
        if extra_repr:
            extra_lines = extra_repr.split("\n")
        child_lines = []
        for name in self._modules:
            if _is_module(self.__dict__[name]):
                child_lines.append(
                    "(" + name + "): " + add_indent(repr(self.__dict__[name]), 2)
                )
            else:
                for k, v in _expand_structure(name, self.__dict__[name]):
                    if _is_module(v):
                        child_lines.append("(" + k + "): " + add_indent(repr(v), 2))
        lines = extra_lines + child_lines
        main_str = self.__class__.__name__ + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str
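

# Editor's note: a hedged usage sketch (not part of the upstream file), kept as
# comments because this module is imported as part of the megengine package.
# It assumes the public re-exports ``megengine.Parameter``, ``megengine.Tensor``
# and ``megengine.module.Module`` of released MegEngine versions:
#
#     import numpy as np
#     from megengine import Parameter, Tensor
#     from megengine.module import Module
#
#     class Scale(Module):
#         def __init__(self):
#             super().__init__()
#             self.scale = Parameter(np.ones((1,), dtype="float32"))
#
#         def forward(self, x):
#             return x * self.scale
#
#     net = Scale()
#     print([k for k, _ in net.named_parameters()])    # ['scale']
#     out = net(Tensor(np.arange(4, dtype="float32")))
#     net.load_state_dict(net.state_dict())            # no-op round trip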
