
tensor.py

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import functools
import itertools
import weakref
from typing import Callable, Tuple, Union

import numpy as np

import megengine._internal as mgb

from .graph import _use_default_if_none, get_default_graph


def wrap_io_tensor(func):
    r"""A wrapper to make ``func`` compatible with functions in ``_internal.opr``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        comp_graph = None
        for i in itertools.chain(args, kwargs.values()):
            if isinstance(i, Tensor) and i._comp_graph:
                comp_graph = i._comp_graph
                break
        else:
            comp_graph = get_default_graph()
        new_args = (
            arg._attach(comp_graph) if isinstance(arg, Tensor) else arg for arg in args
        )
        new_kwargs = {
            k: v._attach(comp_graph) if isinstance(v, Tensor) else v
            for k, v in kwargs.items()
        }
        ret = func(*new_args, **new_kwargs)
        if isinstance(ret, mgb.SymbolVar):
            ret = Tensor(ret)
        elif isinstance(ret, list):
            ret = [Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret]
        elif isinstance(ret, tuple):
            ret = tuple(Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret)
        return ret

    return wrapper
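

# A minimal usage sketch (not part of the original file): ``wrap_io_tensor``
# lifts a raw ``_internal.opr`` function into one that takes and returns
# ``Tensor``. ``Tensor.to`` below is built exactly this way from
# ``mgb.opr.copy``; the same pattern works for other oprs, e.g.
#
#     copy = wrap_io_tensor(mgb.opr.copy)
#     y = copy(x, comp_node="cpu0")  # x is a Tensor; y is a Tensor, not a SymbolVar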


def _wrap_symbolvar_binary_op(f):
    @functools.wraps(f)
    def wrapped(self, other):
        comp_graph = (
            isinstance(other, Tensor)
            and other._comp_graph
            or self._comp_graph
            or get_default_graph()
        )
        if isinstance(other, Tensor):
            other = other._attach(comp_graph)
        return Tensor(f(self._attach(comp_graph), other))

    return wrapped
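

# Sketch of how the wrapper above is used (see the operator table in ``Tensor``
# below): each arithmetic dunder is the corresponding ``mgb.SymbolVar`` dunder,
# lifted so that both operands are attached to a common computing graph first, e.g.
#
#     __add__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__add__)
#     # tensor([1, 2]) + 1  ->  Tensor([2 3])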


def _wrap_slice(inp: slice):
    r"""
    A wrapper to handle Tensor values in the slice ``inp``.
    """
    start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start
    stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop
    step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step
    return slice(start, stop, step)


def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]):
    r"""
    A wrapper to handle Tensor values in ``idx``.
    """
    if not isinstance(idx, tuple):
        idx = (idx,)
    idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx)
    idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx)
    return idx


class _MGBIndexWrapper:
    r"""
    A wrapper class to handle ``__getitem__`` for indices containing Tensor values.

    :param dest: a destination Tensor to do indexing on.
    :param mgb_index: an ``_internal`` helper function indicating how to index.
    :param val: an optional Tensor parameter used by ``mgb_index``.
    """

    def __init__(self, dest: "Tensor", mgb_index: Callable, val=None):
        self.dest = dest
        self.val = val
        self.mgb_index = mgb_index

    def __getitem__(self, idx):
        if self.val is None:
            return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)(
                _wrap_idx(idx)
            )
        else:
            return wrap_io_tensor(
                self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__
            )(_wrap_idx(idx))
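

# Illustrative sketch (not part of the original file): ``_MGBIndexWrapper`` turns
# an indexing opr into an object indexed with ``[]``. ``Tensor.ai`` below returns
# ``_MGBIndexWrapper(self, mgb.opr.advanced_indexing)``, so
#
#     a.ai[:, [2, 3]]
#
# dispatches to ``mgb.opr.advanced_indexing`` with any Tensor values in the
# index unwrapped by ``_wrap_idx``.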


class _Guard:
    r"""
    A wrapper class whose custom ``__del__`` method calls ``deleter``.

    :param deleter: a function to be called in ``__del__``.
    """

    def __init__(self, deleter: Callable):
        self.deleter = deleter

    def __del__(self):
        self.deleter()


class Tensor:
    r"""The main data container in MegEngine.

    Use :func:`~.tensor` to create a Tensor from existing data.
    """

    requires_grad = False
    grad = None

    def __init__(self, val=None, *, requires_grad=None):
        self._reset(val, requires_grad=requires_grad)

    def _reset(self, val=None, *, requires_grad=None):
        self.__sym_override = None
        if val is None:
            self.__val = None
            self.__sym = None
        elif isinstance(val, mgb.SharedND):
            self.__val = val
            self.__sym = None
        elif isinstance(val, mgb.SymbolVar):
            self.__val = None
            self.__sym = val
        else:
            raise TypeError("must be initialized with SymbolVar or SharedND")
        self.requires_grad = requires_grad

    def _as_tensor(self, obj):
        r"""Convert the data into a ``Tensor``. If the data is already a Tensor
        with the same dtype and device, no copy will be performed. Otherwise a
        new Tensor will be returned with the computing graph retained.
        """
        if isinstance(obj, Tensor):
            return obj
        if isinstance(obj, mgb.SymbolVar):
            return Tensor(obj)
        if isinstance(obj, mgb.SharedScalar):
            return Tensor(obj._as_sym_var(self._comp_graph, self._comp_node))
        return tensor(data=obj, device=self.device)

    def numpy(self):
        r"""Return the tensor value as a :class:`numpy.ndarray`.
        """
        if self.__val is not None:
            assert self.__sym is None
            return self.__val.get_value()
        if self.__sym is None:
            raise ValueError("uninitialized")
        if self.__sym.eager_val is not None:
            return self.__sym.eager_val.get_value()
        return self.__sym.inferred_value

    def item(self):
        r"""If the tensor holds exactly one value, return it."""
        return self.numpy().item()

    def _attach(self, comp_graph, *, volatile=True):
        sym = self.__sym_override or self.__sym
        if sym:
            if sym.owner_graph != comp_graph:
                raise RuntimeError("internal error")
            return sym
        if self.__val:
            return self.__val.symvar(comp_graph, volatile=volatile)
        else:
            raise ValueError("uninitialized")

    @property
    def _symvar(self):
        if self.__sym_override:
            return self.__sym_override
        if self.__sym:
            assert not self.__val
            return self.__sym
        if not self.__val:
            raise ValueError("uninitialized")
        return self._attach(get_default_graph())

    def __mgb_symvar__(self, comp_graph=None, **_):
        if self.__sym_override:
            return self.__sym_override
        if self.__val and comp_graph:
            return self._attach(comp_graph)
        return self._symvar  # read by mgb.opr

    def _override_symvar_during_trace(self, trace, symvar):
        assert self.__val and not self.__sym
        assert trace is type(trace)._active_instance
        deleters = trace._user_cache.setdefault(Tensor, set())
        self_ref = weakref.ref(self)

        def restore():
            self = self_ref()
            if self is not None:
                self.__sym_override = None

        deleters.add(_Guard(restore))
        self.__sym_override = symvar

    @property
    def dtype(self):
        r"""Return the data type of the tensor.
        """
        if self.__val is not None:
            return self.__val.dtype
        return self._symvar.dtype

    def set_dtype(self, dtype: str = None):
        r"""Set the data type of the tensor.
        """
        if self.__val is not None:
            self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
        elif self.__sym is not None:
            self.__sym = self.__sym.astype(dtype)

    @property
    def _comp_node(self):
        if self.__val is not None:
            return self.__val.comp_node
        return self._symvar.comp_node

    device = _comp_node

    @property
    def _comp_graph(self):
        if self.__sym is not None:
            return self.__sym.owner_graph
        return None

    @property
    def shape(self):
        r"""Return an int tuple that is the shape/layout of the tensor.
        May be invalid in static graph mode.
        """
        from ..jit import trace

        if trace._active_instance:  # pylint: disable=protected-access
            # NOTE: this is a hack
            shape = mgb.opr.get_var_shape(self._symvar)
            return tuple(Tensor(shape[i]) for i in range(self.ndim))
        return self._symvar.imm_shape

    def set_value(self, value, *, sync=True, inplace=False, share=False):
        r"""Set the value of the tensor.
        """
        if not self.__val:
            raise ValueError("not detached")
        if isinstance(value, Tensor):
            value = value.__val or value.__sym.eager_val
        self.__val.set_value(value, sync=sync, inplace=inplace, share=share)

    def fill(self, value):
        r"""Fill the tensor with the specified value.
        """
        self.set_value(np.full(self.shape, value, dtype=self.dtype))

    def reset_zero(self):
        r"""Reset the tensor and fill it with zeros.
        """
        if not self.__val:
            raise ValueError("not detached")
        self.__val.reset_zero()

    def to(self, device):
        r"""Perform Tensor device conversion; return a Tensor on the specified device.
        """
        return wrap_io_tensor(mgb.opr.copy)(self, comp_node=device)

    # https://docs.python.org/3/reference/datamodel.html#object.__hash__
    # > If a class does not define an __eq__() method it should not define a
    # > __hash__() operation either
    __hash__ = None  # type: ignore[assignment]

    def __eq__(self, rhs):
        rhs = self._as_tensor(rhs)
        return Tensor(self._symvar._binary_opr("EQ", rhs._symvar))

    def __ne__(self, rhs):
        return 1 - self.__eq__(rhs)

    def __len__(self):
        if self._symvar.eager_val is not None:
            return self._symvar.eager_val.shape[0]
        raise TypeError(
            "__len__ and __iter__ are not available for tensors on a non-eager graph."
        )

    __add__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__add__)
    __radd__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__radd__)
    __sub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__sub__)
    __rsub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rsub__)
    __mul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mul__)
    __rmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmul__)
    __matmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__matmul__)
    __rmatmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmatmul__)
    __lshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lshift__)
    __rshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rshift__)
    __truediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__truediv__)
    __rtruediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rtruediv__)
    __floordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__floordiv__)
    __rfloordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rfloordiv__)
    __mod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mod__)
    __rmod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmod__)
    __pow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__pow__)
    __rpow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rpow__)
    __lt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lt__)
    __gt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__gt__)
    __le__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__le__)
    __ge__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__ge__)
    __neg__ = wrap_io_tensor(mgb.SymbolVar.__neg__)

    sum = wrap_io_tensor(mgb.SymbolVar.sum)
    """
    Return the sum of the tensor.
    """

    max = wrap_io_tensor(mgb.SymbolVar.max)
    """
    Return the maximum value of the tensor.
    """

    min = wrap_io_tensor(mgb.SymbolVar.min)
    """
    Return the minimum value of the tensor.
    """

    prod = wrap_io_tensor(mgb.SymbolVar.prod)
    """
    Return the product of the tensor.
    """

    mean = wrap_io_tensor(mgb.SymbolVar.mean)
    """
    Return the mean value of the tensor.
    """

    dimshuffle = wrap_io_tensor(mgb.SymbolVar.dimshuffle)
    """
    See more details in :func:`~.functional.tensor.dimshuffle`.
    """

    astype = wrap_io_tensor(mgb.SymbolVar.astype)
    """
    Cast the tensor to a specified type.
    """

    def reshape(self, *target_shape):
        r"""Return a tensor with the given target shape.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            inp = tensor(np.arange(1, 17, dtype=np.int32).reshape(4, 4))
            out = tensor(np.arange(100, 116, dtype=np.int32).reshape(1, 16))
            out = out.reshape(inp.shape)
            print(out.numpy())

        .. testoutput::

            [[100 101 102 103]
             [104 105 106 107]
             [108 109 110 111]
             [112 113 114 115]]
        """
        if isinstance(target_shape[0], tuple):
            if len(target_shape) > 1:
                raise ValueError("Only a single tuple is accepted in reshape")
            target_shape = target_shape[0]
        target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
        return Tensor(mgb.SymbolVar.reshape(self._symvar, *target_shape))

    def broadcast(self, *target_shape):
        r"""Return a tensor broadcast from the current tensor to the given target shape.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            data = tensor(np.arange(100, 104, dtype=np.int32).reshape(1, 4))
            data = data.broadcast((4, 4))
            print(data.numpy())

        .. testoutput::

            [[100 101 102 103]
             [100 101 102 103]
             [100 101 102 103]
             [100 101 102 103]]
        """
        if isinstance(target_shape[0], tuple):
            if len(target_shape) > 1:
                raise ValueError("Only a single tuple is accepted in broadcast")
            target_shape = target_shape[0]
        target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
        return Tensor(mgb.SymbolVar.broadcast(self._symvar, *target_shape))

    # Prefer operators on Tensor instead of converting to numpy
    __array_priority__ = 1000

    # mgb indexing family
    def __getitem__(self, idx):
        return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))

    def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Return an object which supports using ``__getitem__`` to set a subtensor.

        ``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy(); c[idx] = b``.
        """
        return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val)

    def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Return an object which supports using ``__getitem__`` to increment a subtensor.

        ``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy(); c[idx] += b``.
        """
        return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)
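
    # Illustrative sketch (not part of the original file): subtensor updates are
    # functional, leaving the source tensor untouched. Assuming eager mode:
    #
    #     a = tensor(np.zeros((3, 3), dtype=np.float32))
    #     b = tensor(np.ones((3,), dtype=np.float32))
    #     c = a.set_subtensor(b)[0]   # row 0 of c equals b; a is unchanged
    #     d = a.incr_subtensor(b)[0]  # row 0 of d equals a[0] + b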

    @property
    def ai(self) -> _MGBIndexWrapper:
        r"""
        Return an object which supports advanced indexing to get a subtensor.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
            print(a.ai[:, [2, 3]])

        Outputs:

        .. testoutput::

            Tensor([[ 2.  3.]
             [ 6.  7.]
             [10. 11.]
             [14. 15.]])
        """
        return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)

    def set_ai(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.set_subtensor`, but supports advanced indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)

    def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.incr_subtensor`, but supports advanced indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)

    @property
    def mi(self) -> _MGBIndexWrapper:
        r"""
        Return an object which supports getting a subtensor at the
        coordinates given by the Cartesian product of the index values.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
            print(a.mi[[1, 2], [2, 3]])
            # selects the elements at [1, 2] x [2, 3] = [[(1, 2), (1, 3)], [(2, 2), (2, 3)]]
            # a[1, 2] = 6, a[1, 3] = 7, a[2, 2] = 10, a[2, 3] = 11

        Outputs:

        .. testoutput::

            Tensor([[ 6.  7.]
             [10. 11.]])
        """
        return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)

    def set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.set_subtensor`, but uses mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)

    def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.incr_subtensor`, but uses mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)

    @property
    def batched_mi(self) -> _MGBIndexWrapper:
        r"""
        Return an object which supports getting a subtensor by batched mesh indexing.

        For a Tensor ``a`` and index ``idx``, each value in ``idx`` must be a 2-dim
        matrix or a slice. The Cartesian product
        ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` selects a subtensor from ``a[i]``.
        Each matrix ``idx[k]`` must have as many rows as the batch dim indicated by ``idx[0]``.
        A slice value applies the same slice to every batch. See the example below for details.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4)))
            print(a.batched_mi[:2, [[0], [1]], [[0, 1], [2, 3]], [[0], [1]]])
            # selects elements from a[0] at [0] * [0, 1] * [0] = [[[(0, 0, 0)], [(0, 1, 0)]]] (shape [1, 2, 1])
            # and from a[1] at [1] * [2, 3] * [1] = [[[(1, 2, 1)], [(1, 3, 1)]]] (shape also [1, 2, 1])
            # a[0, 0, 0, 0] = 0, a[0, 0, 1, 0] = 4, a[1, 1, 2, 1] = 73, a[1, 1, 3, 1] = 77
            print(a.batched_mi[:2, [[0], [1]], :2, :1])
            # is equal to a.batched_mi[:2, [[0], [1]], [[0, 1], [0, 1]], [[0], [0]]]

        Outputs:

        .. testoutput::

            Tensor([[[[ 0.]
               [ 4.]]]
             [[[73.]
               [77.]]]])
            Tensor([[[[ 0.]
               [ 4.]]]
             [[[64.]
               [68.]]]])
        """
        return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)

    def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.set_subtensor`, but uses batched mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)

    def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.incr_subtensor`, but uses batched mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)

    def __array__(self, dtype=None):
        if dtype is None:
            return self.numpy()
        else:
            return self.numpy().astype(dtype, copy=False)

    def __int__(self):
        return int(self.item())

    def __index__(self):
        return int(self.item())

    def __round__(self, ndigits=0):
        if ndigits != 0:
            raise ValueError("ndigits must be 0 for Tensor.round")
        return Tensor(mgb.opr.elemwise([self._symvar], mode="ROUND"))

    round = __round__

    def sqrt(self):
        r"""Return a tensor whose elements are the square roots of the
        original values.
        """
        return Tensor(mgb.opr.sqrt(self._symvar))

    def shapeof(self, axis=None):
        r"""Return a Tensor that represents the shape of this tensor.
        """
        return Tensor(mgb.opr.get_var_shape(self._symvar, axis=axis))

    @property
    def ndim(self):
        r"""Return the number of dimensions of the tensor.
        """
        return len(self._symvar.imm_shape)

    def __repr__(self):
        piece = "Tensor("
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        if self._comp_node.locator_logical != ("XPU", -1, 0):
            piece += ", device={}".format(self.device)
        piece += ")"
        return piece

    def __bool__(self):
        raise RuntimeError(
            "Tensor object should not be converted to bool or used in an if "
            "statement. Use .numpy(), int() or float() if you want to use its "
            "value in an if statement; be aware that this may lead to incorrect "
            "results in non-eager mode."
        )

    def __getstate__(self):
        r"""``__getstate__`` is called for pickle serialization or deep copy.
        """
        assert (self.__val is not None) and (
            self.__sym is None
        ), "Only a SharedND-initialized Tensor can be serialized or deep copied"
        metadata = {"requires_grad": self.requires_grad}
        state = {
            "data": self.numpy(),
            "device": self.device,
            "dtype": self.dtype,
            "metadata": metadata,
        }
        return state

    def __setstate__(self, state):
        data = state.pop("data")
        device = state.pop("device")
        dtype = state.pop("dtype")
        metadata = state.pop("metadata", {})
        requires_grad = metadata.pop("requires_grad", None)
        snd = mgb.make_shared(device, value=data, dtype=dtype)
        self._reset(snd, requires_grad=requires_grad)
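

# Illustrative sketch (not part of the original file): thanks to
# ``__getstate__``/``__setstate__`` above, a value-holding (SharedND) Tensor
# survives pickling and ``copy.deepcopy``, e.g.
#
#     import pickle
#     t = tensor([1.0, 2.0])
#     t2 = pickle.loads(pickle.dumps(t))  # same data, dtype, device, requires_grad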


def tensor(
    data: Union[list, np.ndarray] = None,
    *,
    dtype: str = None,
    device: mgb.CompNode = None,
    requires_grad: bool = None
):
    r"""A helper function to create a :class:`~.Tensor` from existing data.

    :param data: an existing data array; must be a Python list, a NumPy array or None.
    :param dtype: target Tensor data type, one of ``("uint8", "int8", "int16", "int32", "float32", "float16")``.
    :param device: target device for storing the Tensor.
    :param requires_grad: whether its gradient will be calculated during :meth:`~.Optimizer.backward`
    """
    supported_dtypes = ("uint8", "int8", "int16", "int32", "float32", "float16")

    if isinstance(data, Tensor):
        raise NotImplementedError
    if dtype is not None and np.dtype(dtype).name not in supported_dtypes:
        raise TypeError("unsupported dtype {}".format(dtype))

    if data is not None:
        if not isinstance(data, np.ndarray):
            data = np.array(data, dtype=dtype)
            # In order to accept tensor([1]), automatically convert to 32-bit
            # numbers instead of numpy's default 64-bit when the input data is
            # not an ndarray.
            dtype = mgb.to_mgb_supported_dtype(data.dtype)
        if dtype is None:
            if data.dtype.name not in supported_dtypes:
                raise TypeError("unsupported dtype {}".format(data.dtype))

    device, _ = _use_default_if_none(device, None)
    shared_nd = mgb.make_shared(device, value=data, dtype=dtype)
    return Tensor(shared_nd, requires_grad=requires_grad)
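

# Illustrative sketch (not part of the original file): typical ways to call the
# helper above, e.g.
#
#     x = tensor([[1, 2], [3, 4]])                  # list input, converted to int32
#     y = tensor(np.ones((2, 2)), dtype="float32")  # ndarray input with explicit dtype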


class TensorDict(collections.abc.MutableMapping):
    r"""
    A helper class to maintain a dict with Tensor keys.
    """

    def __init__(self, *args, **kwargs):
        self.data = {}
        for i in args:
            self.update(i)
        self.update(**kwargs)

    class keyfn:
        def __new__(cls, x: Tensor):
            if not isinstance(x, Tensor):
                return x
            return super().__new__(cls)

        def __init__(self, x: Tensor):
            self._data = x  # do not save the id directly, to keep pickle working

        def __hash__(self):
            return id(self._data)

        def __eq__(self, other):
            return isinstance(other, type(self)) and id(self._data) == id(other._data)

    def __getitem__(self, key):
        _, v = self.data[self.keyfn(key)]
        return v

    def __setitem__(self, key, value):
        self.data[self.keyfn(key)] = key, value

    def __delitem__(self, key):
        del self.data[self.keyfn(key)]

    def __iter__(self):
        for _, (k, _) in self.data.items():
            yield k

    def __len__(self):
        return len(self.data)
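

# Illustrative sketch (not part of the original file): ``TensorDict`` keys by
# object identity (via ``keyfn``), so two equal-valued Tensors are distinct keys:
#
#     a, b = tensor([1.0]), tensor([1.0])
#     d = TensorDict()
#     d[a] = "first"
#     d[b] = "second"
#     assert len(d) == 2  # identity, not value, distinguishes the keys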

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.