
tensor.py

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import copy
import functools
import itertools
import weakref
from typing import Callable, Tuple, Union

import numpy as np

import megengine._internal as mgb

from .graph import _use_default_if_none, get_default_graph


def wrap_io_tensor(func):
    r"""A wrapper to make ``func`` compatible with functions in ``_internal.opr``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        comp_graph = None
        for i in itertools.chain(args, kwargs.values()):
            if isinstance(i, Tensor) and i._comp_graph:
                comp_graph = i._comp_graph
                break
        else:
            comp_graph = get_default_graph()
        new_args = (
            arg._attach(comp_graph) if isinstance(arg, Tensor) else arg for arg in args
        )
        new_kwargs = {
            k: v._attach(comp_graph) if isinstance(v, Tensor) else v
            for k, v in kwargs.items()
        }
        ret = func(*new_args, **new_kwargs)
        if isinstance(ret, mgb.SymbolVar):
            ret = Tensor(ret)
        elif isinstance(ret, list):
            ret = [Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret]
        elif isinstance(ret, tuple):
            ret = tuple(Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret)
        return ret

    return wrapper


def _wrap_symbolvar_binary_op(f):
    @functools.wraps(f)
    def wrapped(self, other):
        comp_graph = (
            isinstance(other, Tensor)
            and other._comp_graph
            or self._comp_graph
            or get_default_graph()
        )
        if isinstance(other, Tensor):
            other = other._attach(comp_graph)
        return Tensor(f(self._attach(comp_graph), other))

    return wrapped


def _wrap_slice(inp: slice):
    r"""
    A wrapper to handle Tensor values in the ``inp`` slice.
    """
    start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start
    stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop
    step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step
    return slice(start, stop, step)


def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]):
    r"""
    A wrapper to handle Tensor values in ``idx``.
    """
    if not isinstance(idx, tuple):
        idx = (idx,)
    idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx)
    idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx)
    return idx


class _MGBIndexWrapper:
    r"""
    A wrapper class to handle ``__getitem__`` for an index containing Tensor values.

    :param dest: a destination Tensor to do indexing on.
    :param mgb_index: an ``_internal`` helper function indicating how to index.
    :param val: an optional Tensor parameter used by ``mgb_index``.
    """

    def __init__(self, dest: "Tensor", mgb_index: Callable, val=None):
        self.dest = dest
        self.val = val
        self.mgb_index = mgb_index

    def __getitem__(self, idx):
        if self.val is None:
            return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)(
                _wrap_idx(idx)
            )
        else:
            return wrap_io_tensor(
                self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__
            )(_wrap_idx(idx))
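
# For illustration (a hypothetical expansion, not additional API):
# ``a.set_subtensor(b)[1:3]`` constructs
# ``_MGBIndexWrapper(a, mgb.opr.set_subtensor, b)`` and then calls the
# ``__getitem__`` above with ``idx = slice(1, 3)``.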


class _Guard:
    r"""
    A wrapper class with a custom ``__del__`` method calling ``deleter``.

    :param deleter: a function to be called in ``__del__``.
    """

    def __init__(self, deleter: Callable):
        self.deleter = deleter

    def __del__(self):
        self.deleter()


class Tensor:
    r"""The main data container in MegEngine.

    Use :func:`~.tensor` to create a Tensor from existing data.
    """

    requires_grad = False
    grad = None

    def __init__(self, val=None, *, requires_grad=None):
        self._reset(val, requires_grad=requires_grad)
        self.q_dict = {"mode": None, "scale": None, "zero_point": None}

    def _reset(self, val=None, *, requires_grad=None):
        self.__sym_override = None
        if val is None:
            self.__val = None
            self.__sym = None
        elif isinstance(val, mgb.SharedND):
            self.__val = val
            self.__sym = None
        elif isinstance(val, mgb.SymbolVar):
            self.__val = None
            self.__sym = val
        else:
            raise TypeError("must be initialized with SymbolVar or SharedND")
        self.requires_grad = requires_grad

    def _as_tensor(self, obj):
        r"""Convert the data into a ``Tensor``. If the data is already a Tensor
        with the same dtype and device, no copy will be performed. Otherwise a
        new Tensor will be returned with the computational graph retained.
        """
        if isinstance(obj, Tensor):
            return obj
        if isinstance(obj, mgb.SymbolVar):
            return Tensor(obj)
        if isinstance(obj, mgb.SharedScalar):
            return Tensor(obj._as_sym_var(self._comp_graph, self._comp_node))
        return tensor(data=obj, device=self.device)

    def numpy(self):
        r"""Return the tensor value in numpy.ndarray format.
        """
        if self.__val is not None:
            assert self.__sym is None
            return self.__val.get_value()
        if self.__sym is None:
            raise ValueError("uninitialized")
        if self.__sym.eager_val is not None:
            return self.__sym.eager_val.get_value()
        return self.__sym.inferred_value

    def item(self):
        r"""If the tensor has only one value, return it."""
        return self.numpy().item()

    def _attach(self, comp_graph, *, volatile=True):
        sym = self.__sym_override or self.__sym
        if sym:
            if sym.owner_graph != comp_graph:
                raise RuntimeError("internal error")
            return sym
        if self.__val:
            return self.__val.symvar(comp_graph, volatile=volatile)
        else:
            raise ValueError("uninitialized")

    @property
    def _symvar(self):
        if self.__sym_override:
            return self.__sym_override
        if self.__sym:
            assert not self.__val
            return self.__sym
        if not self.__val:
            raise ValueError("uninitialized")
        return self._attach(get_default_graph())

    def __mgb_symvar__(self, comp_graph=None, **_):
        if self.__sym_override:
            return self.__sym_override
        if self.__val and comp_graph:
            return self._attach(comp_graph)
        return self._symvar  # read by mgb.opr

    def _override_symvar_during_trace(self, trace, symvar):
        assert self.__val and not self.__sym
        assert trace is type(trace)._active_instance
        deleters = trace._user_cache.setdefault(Tensor, set())
        self_ref = weakref.ref(self)

        def restore():
            self = self_ref()
            if self is not None:
                self.__sym_override = None

        deleters.add(_Guard(restore))
        self.__sym_override = symvar

    @property
    def dtype(self):
        r"""Return the data type of the tensor.
        """
        if self.__val is not None:
            return self.__val.dtype
        return self._symvar.dtype

    @dtype.setter
    def dtype(self, dtype: str = None):
        r"""Set the data type of the tensor.
        """
        if self.__val is not None:
            self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
        elif self.__sym_override is not None:
            self.__sym_override = self.__sym_override.astype(dtype)
        elif self.__sym is not None:
            self.__sym = self.__sym.astype(dtype)

    @property
    def name(self):
        r"""Get the tensor name; not supported for Parameter and Buffer.
        """
        return self._symvar.name

    @name.setter
    def name(self, name: str = None):
        r"""Set the tensor name; not supported for Parameter and Buffer.
        """
        if self.__val is not None:
            raise ValueError("name setting is not available for Parameter or Buffer.")
        if self.__sym_override is not None:
            self.__sym_override = self.__sym_override.rename(name)
        if self.__sym is not None:
            assert not self.__val
            self.__sym = self.__sym.rename(name)

    @property
    def _comp_node(self):
        if self.__val is not None:
            return self.__val.comp_node
        return self._symvar.comp_node

    device = _comp_node

    @property
    def _comp_graph(self):
        if self.__sym is not None:
            return self.__sym.owner_graph
        return None

    @property
    def shape(self):
        r"""Return an int tuple that is the shape/layout of the tensor.
        Could be invalid in static graph mode.
        """
        from ..jit import trace

        if trace._active_instance:  # pylint: disable=protected-access
            # NOTE: this is a hack
            shape = mgb.opr.get_var_shape(self._symvar)
            return tuple(Tensor(shape[i]) for i in range(self.ndim))
        return self._symvar.imm_shape

    def set_value(self, value, *, sync=True, inplace=False, share=False):
        r"""Set the value of the tensor.
        """
        if not self.__val:
            raise ValueError("not detached")
        if isinstance(value, Tensor):
            value = value.__val or value.__sym.eager_val
        self.__val.set_value(value, sync=sync, inplace=inplace, share=share)

    def fill(self, value):
        r"""Fill the tensor with the specified value.
        """
        self.set_value(np.full(self.shape, value, dtype=self.dtype))

    def reset_zero(self):
        r"""Reset the tensor and fill it with zeros.
        """
        if not self.__val:
            raise ValueError("not detached")
        self.__val.reset_zero()

    def to(self, device):
        r"""Perform Tensor device conversion; return a Tensor on the specified device.
        """
        return wrap_io_tensor(mgb.opr.copy)(self, comp_node=device)

    # https://docs.python.org/3/reference/datamodel.html#object.__hash__
    # > If a class does not define an __eq__() method it should not define a
    # > __hash__() operation either
    __hash__ = None  # type: ignore[assignment]

    def __eq__(self, rhs):
        rhs = self._as_tensor(rhs)
        return Tensor(self._symvar._binary_opr("EQ", rhs._symvar))

    def __ne__(self, rhs):
        return 1 - self.__eq__(rhs)

    def __len__(self):
        if self._symvar.eager_val is not None:
            return self._symvar.eager_val.shape[0]
        raise TypeError(
            "__len__ and __iter__ are not available for tensors on a non-eager graph."
        )

    __add__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__add__)
    __radd__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__radd__)
    __sub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__sub__)
    __rsub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rsub__)
    __mul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mul__)
    __rmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmul__)
    __matmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__matmul__)
    __rmatmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmatmul__)
    __lshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lshift__)
    __rshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rshift__)
    __truediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__truediv__)
    __rtruediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rtruediv__)
    __floordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__floordiv__)
    __rfloordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rfloordiv__)
    __mod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mod__)
    __rmod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmod__)
    __pow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__pow__)
    __rpow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rpow__)
    __lt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lt__)
    __gt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__gt__)
    __le__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__le__)
    __ge__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__ge__)
    __neg__ = wrap_io_tensor(mgb.SymbolVar.__neg__)

    sum = wrap_io_tensor(mgb.SymbolVar.sum)
    """
    Return the sum of the given tensor.
    """

    max = wrap_io_tensor(mgb.SymbolVar.max)
    """
    Return the maximum value of the given tensor.
    """

    min = wrap_io_tensor(mgb.SymbolVar.min)
    """
    Return the minimum value of the given tensor.
    """

    prod = wrap_io_tensor(mgb.SymbolVar.prod)
    """
    Return the product of the given tensor.
    """

    mean = wrap_io_tensor(mgb.SymbolVar.mean)
    """
    Return the mean value of the given tensor.
    """

    dimshuffle = wrap_io_tensor(mgb.SymbolVar.dimshuffle)
    """
    See more details in :func:`~.functional.tensor.dimshuffle`.
    """

    astype = wrap_io_tensor(mgb.SymbolVar.astype)
    """
    Cast the tensor to a specified type.
    """

    def reshape(self, *target_shape):
        r"""Return a tensor with the given target shape.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            inp = tensor(np.arange(1, 17, dtype=np.int32).reshape(4, 4))
            out = tensor(np.arange(100, 116, dtype=np.int32).reshape(1, 16))
            out = out.reshape(inp.shape)
            print(out.numpy())

        .. testoutput::

            [[100 101 102 103]
             [104 105 106 107]
             [108 109 110 111]
             [112 113 114 115]]
        """
        if isinstance(target_shape[0], tuple):
            if len(target_shape) > 1:
                raise ValueError("Only a single tuple is accepted in reshape")
            target_shape = target_shape[0]
        target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
        return Tensor(mgb.SymbolVar.reshape(self._symvar, *target_shape))

    def broadcast(self, *target_shape):
        r"""Return a tensor broadcast from the current tensor to the given target shape.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            data = tensor(np.arange(100, 104, dtype=np.int32).reshape(1, 4))
            data = data.broadcast((4, 4))
            print(data.numpy())

        .. testoutput::

            [[100 101 102 103]
             [100 101 102 103]
             [100 101 102 103]
             [100 101 102 103]]
        """
        if isinstance(target_shape[0], tuple):
            if len(target_shape) > 1:
                raise ValueError("Only a single tuple is accepted in broadcast")
            target_shape = target_shape[0]
        target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
        return Tensor(mgb.SymbolVar.broadcast(self._symvar, *target_shape))

    # Prefer operators on Tensor instead of converting to numpy
    __array_priority__ = 1000

    # mgb indexing family
    def __getitem__(self, idx):
        return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))

    def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Return an object which supports using ``__getitem__`` to set a subtensor.

        ``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` followed by ``c[idx] = b``.
        """
        return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val)

    def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Return an object which supports using ``__getitem__`` to increment a subtensor.

        ``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` followed by ``c[idx] += b``.
        """
        return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)

    @property
    def ai(self) -> _MGBIndexWrapper:
        r"""
        Return an object which supports complex index methods to get a subtensor.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
            print(a.ai[:, [2, 3]])

        Outputs:

        .. testoutput::

            Tensor([[ 2.  3.]
             [ 6.  7.]
             [10. 11.]
             [14. 15.]])
        """
        return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)

    def set_ai(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equivalent to :meth:`~.Tensor.set_subtensor`, but supports advanced indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)

    def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equivalent to :meth:`~.Tensor.incr_subtensor`, but supports advanced indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)

    @property
    def mi(self) -> _MGBIndexWrapper:
        r"""
        Return an object which supports getting a subtensor at coordinates
        that are the Cartesian product of the given indices.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
            print(a.mi[[1, 2], [2, 3]])
            # equal to the elements at [1, 2] * [2, 3] = [[(1, 2), (1, 3)], [(2, 2), (2, 3)]]
            # a[1, 2] = 6, a[1, 3] = 7, a[2, 2] = 10, a[2, 3] = 11

        Outputs:

        .. testoutput::

            Tensor([[ 6.  7.]
             [10. 11.]])
        """
        return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)

    def set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equivalent to :meth:`~.Tensor.set_subtensor`, but using mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)

    def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equivalent to :meth:`~.Tensor.incr_subtensor`, but using mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)

    @property
    def batched_mi(self) -> _MGBIndexWrapper:
        r"""
        Return an object which supports getting a subtensor by batched mesh indexing.

        For a Tensor ``a`` and an index ``idx``, each value of ``idx`` needs to be
        either a 2-dim matrix or a slice. The Cartesian product
        ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` forms a subtensor of
        ``a[i]``. Each matrix ``idx[k]`` should have ``batched_dim`` rows, as
        indicated by ``idx[0]``; a slice value applies the same slice to every
        batch. See the example below for more details.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4)))
            print(a.batched_mi[:2, [[0], [1]], [[0, 1], [2, 3]], [[0], [1]]])
            # equal to the elements from a[0] at [0] * [0, 1] * [0] = [[[(0, 0, 0)], [(0, 1, 0)]]] (shape is [1, 2, 1])
            # and from a[1] at [1] * [2, 3] * [1] = [[[(1, 2, 1)], [(1, 3, 1)]]] (shape is also [1, 2, 1])
            # a[0, 0, 0, 0] = 0, a[0, 0, 1, 0] = 4, a[1, 1, 2, 1] = 73, a[1, 1, 3, 1] = 77
            print(a.batched_mi[:2, [[0], [1]], :2, :1])
            # equal to a.batched_mi[:2, [[0], [1]], [[0, 1], [0, 1]], [[0], [0]]]

        Outputs:

        .. testoutput::

            Tensor([[[[ 0.]
               [ 4.]]]
             [[[73.]
               [77.]]]])
            Tensor([[[[ 0.]
               [ 4.]]]
             [[[64.]
               [68.]]]])
        """
        return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)

    def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equivalent to :meth:`~.Tensor.set_subtensor`, but using batched mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)

    def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equivalent to :meth:`~.Tensor.incr_subtensor`, but using batched mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)

    def __array__(self, dtype=None):
        if dtype is None:
            return self.numpy()
        else:
            return self.numpy().astype(dtype, copy=False)

    def __int__(self):
        return int(self.item())

    def __index__(self):
        return int(self.item())

    def __round__(self, ndigits=0):
        if ndigits != 0:
            raise ValueError("ndigits must be 0 for Tensor.round")
        return Tensor(mgb.opr.elemwise([self._symvar], mode="ROUND"))

    round = __round__

    def sqrt(self):
        r"""Return a tensor whose elements are the square roots of the
        original values.
        """
        return Tensor(mgb.opr.sqrt(self._symvar))

    def shapeof(self, axis=None):
        r"""Return a Tensor that represents the shape of this tensor.
        """
        return Tensor(mgb.opr.get_var_shape(self._symvar, axis=axis))

    @property
    def ndim(self):
        r"""Return the number of dimensions of the tensor.
        """
        return len(self._symvar.imm_shape)

    def __repr__(self):
        piece = "Tensor("
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        if self._comp_node.locator_logical != ("XPU", -1, 0):
            piece += ", device={}".format(self.device)
        piece += ")"
        return piece

    def __bool__(self):
        raise RuntimeError(
            "Tensor object should not be converted to bool or used in an if "
            "statement. Use .numpy(), int() or float() if you want to use its "
            "value in an if statement; be aware that this may lead to incorrect "
            "results in non-eager mode."
        )

    def __getstate__(self):
        r"""``__getstate__`` will be called for pickle serialization or deep copy.
        """
        assert (self.__val is not None) and (
            self.__sym is None
        ), "Only a SharedND-initialized Tensor can be serialized or deep copied"
        metadata = {"requires_grad": self.requires_grad}
        state = {
            "data": self.numpy(),
            "device": self.device,
            "dtype": self.dtype,
            "metadata": metadata,
        }
        return state

    def __setstate__(self, state):
        data = state.pop("data")
        device = state.pop("device")
        dtype = state.pop("dtype")
        metadata = state.pop("metadata", {})
        requires_grad = metadata.pop("requires_grad", None)
        snd = mgb.make_shared(device, value=data, dtype=dtype)
        self._reset(snd, requires_grad=requires_grad)

    def __deepcopy__(self, memo):
        """
        The default deepcopy ignores all attributes except those handled by the
        __getstate__ and __setstate__ methods, so we define __deepcopy__ to
        deep-copy the remaining attributes correctly.
        """
        assert (self.__val is not None) and (
            self.__sym is None
        ), "Only a SharedND-initialized Tensor can be serialized or deep copied"
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        return result


def tensor(
    data: Union[list, np.ndarray] = None,
    *,
    dtype: str = None,
    device: mgb.CompNode = None,
    requires_grad: bool = None
):
    r"""A helper function to create a :class:`~.Tensor` from existing data.

    :param data: an existing data array; must be a Python list, a NumPy ndarray or None.
    :param dtype: target Tensor data type, one of ``("uint8", "int8", "int16", "int32", "float32", "float16")``.
    :param device: target device for Tensor storage.
    :param requires_grad: whether its gradient will be calculated during :meth:`~.Optimizer.backward`.
    """
    supported_dtypes = ("uint8", "int8", "int16", "int32", "float32", "float16")

    if isinstance(data, Tensor):
        raise NotImplementedError
    if dtype is not None and np.dtype(dtype).name not in supported_dtypes:
        raise TypeError("unsupported dtype {}".format(dtype))
    if data is not None:
        if not isinstance(data, np.ndarray):
            data = np.array(data, dtype=dtype)
            # In order to accept tensor([1]), automatically convert to a 32-bit
            # dtype instead of NumPy's default 64-bit when the input is not an ndarray.
            dtype = mgb.to_mgb_supported_dtype(data.dtype)
        if dtype is None:
            if data.dtype.name not in supported_dtypes:
                raise TypeError("unsupported dtype {}".format(data.dtype))
    device, _ = _use_default_if_none(device, None)
    shared_nd = mgb.make_shared(device, value=data, dtype=dtype)
    return Tensor(shared_nd, requires_grad=requires_grad)


class TensorDict(collections.abc.MutableMapping):
    r"""
    A helper class to maintain a dict with Tensor keys.
    """

    def __init__(self, *args, **kwargs):
        self.data = {}
        for i in args:
            self.update(i)
        self.update(**kwargs)

    class keyfn:
        def __new__(cls, x: Tensor):
            if not isinstance(x, Tensor):
                return x
            return super().__new__(cls)

        def __init__(self, x: Tensor):
            self._data = x  # do not save the id directly, so that pickle still works

        def __hash__(self):
            return id(self._data)

        def __eq__(self, other):
            return isinstance(other, type(self)) and id(self._data) == id(other._data)

    def __getitem__(self, key):
        _, v = self.data[self.keyfn(key)]
        return v

    def __setitem__(self, key, value):
        self.data[self.keyfn(key)] = key, value

    def __delitem__(self, key):
        del self.data[self.keyfn(key)]

    def __iter__(self):
        for _, (k, _) in self.data.items():
            yield k

    def __len__(self):
        return len(self.data)
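
A minimal usage sketch of the public pieces defined above (``tensor``, ``Tensor`` operators, the indexing helpers and ``TensorDict``). The module's import path is an assumption (``megengine/core/tensor.py`` in MegEngine 0.x); shapes and values are illustrative only:

import numpy as np

from megengine import tensor
from megengine.core.tensor import TensorDict  # assumed module path

a = tensor(np.arange(16, dtype=np.float32).reshape(4, 4))
b = tensor([1.0, 2.0, 3.0, 4.0])  # lists are converted to 32-bit dtypes, not 64-bit

c = a + b                  # binary ops are forwarded to mgb.SymbolVar and rewrapped
d = a.set_subtensor(b)[0]  # functional update: d equals a except row 0, which is b
e = a.ai[:, [2, 3]]        # advanced indexing returns a new Tensor
print(c.numpy(), d.numpy(), e.numpy())

# TensorDict hashes keys by object identity, so equal-valued tensors stay distinct.
table = TensorDict()
table[a] = "first"
table[b] = "second"
assert table[a] == "first"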

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. If you want to run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
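
For example, selecting the GPU is just a matter of the device string used when creating a tensor (a minimal sketch; "gpu0" assumes at least one GPU is present, and "xpux" means use a GPU if available, otherwise fall back to the CPU):

import numpy as np
from megengine import tensor

x = tensor(np.random.randn(4, 4).astype("float32"), device="gpu0")  # first GPU
y = x.to("cpu0")  # copy back to the CPU (Tensor.to wraps mgb.opr.copy)
print(x.device, y.device)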