
tensor.py 20 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import collections.abc
import functools
import itertools
import weakref
from typing import Union

import numpy as np

import megengine._internal as mgb

from .graph import _use_default_if_none, get_default_graph


def wrap_io_tensor(func):
    r"""A wrapper to make ``func`` compatible with functions in ``_internal.opr``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        comp_graph = None
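        # Added note: pick the computing graph from the first Tensor argument
        # that carries one; the for/else below falls back to the default graph
        # when no argument has a graph attached.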
        for i in itertools.chain(args, kwargs.values()):
            if isinstance(i, Tensor) and i._comp_graph:
                comp_graph = i._comp_graph
                break
        else:
            comp_graph = get_default_graph()
        new_args = (
            arg._attach(comp_graph) if isinstance(arg, Tensor) else arg for arg in args
        )
        new_kwargs = {
            k: v._attach(comp_graph) if isinstance(v, Tensor) else v
            for k, v in kwargs.items()
        }
        ret = func(*new_args, **new_kwargs)
        if isinstance(ret, mgb.SymbolVar):
            ret = Tensor(ret)
        elif isinstance(ret, list):
            ret = [Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret]
        elif isinstance(ret, tuple):
            ret = tuple(Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret)
        return ret

    return wrapper
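
# Illustrative use of wrap_io_tensor (a sketch, not part of the original
# file): wrapping a raw graph operator yields a function that accepts and
# returns Tensor instances, e.g.
#
#     copy_op = wrap_io_tensor(mgb.opr.copy)  # mgb.opr.copy is also used by Tensor.to below
#     y = copy_op(x, comp_node="cpu0")        # x and y are Tensor objects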


def _wrap_symbolvar_binary_op(f):
    @functools.wraps(f)
    def wrapped(self, other):
        comp_graph = (
            isinstance(other, Tensor)
            and other._comp_graph
            or self._comp_graph
            or get_default_graph()
        )
        if isinstance(other, Tensor):
            other = other._attach(comp_graph)
        return Tensor(f(self._attach(comp_graph), other))

    return wrapped


def wrap_slice(inp):
    start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start
    stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop
    step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step
    return slice(start, stop, step)


def wrap_idx(idx):
    if not isinstance(idx, tuple):
        idx = (idx,)
    idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx)
    idx = tuple(wrap_slice(i) if isinstance(i, slice) else i for i in idx)
    return idx


class MGBIndexWrapper:
    def __init__(self, dest, mgb_index, val=None):
        self.dest = dest
        self.val = val
        self.mgb_index = mgb_index

    def __getitem__(self, idx):
        if self.val is None:
            return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)(
                wrap_idx(idx)
            )
        else:
            return wrap_io_tensor(
                self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__
            )(wrap_idx(idx))


class Guard:
    def __init__(self, deleter):
        self.deleter = deleter

    def __del__(self):
        self.deleter()


class Tensor:
    r"""The main data container in MegEngine.

    Use :func:`~.tensor` to create a Tensor with existing data.
    """

    requires_grad = False
    grad = None

    def __init__(self, val=None, *, requires_grad=None):
        self._reset(val, requires_grad=requires_grad)

    def _reset(self, val=None, *, requires_grad=None):
        self.__sym_override = None
        if val is None:
            self.__val = None
            self.__sym = None
        elif isinstance(val, mgb.SharedND):
            self.__val = val
            self.__sym = None
        elif isinstance(val, mgb.SymbolVar):
            self.__val = None
            self.__sym = val
        else:
            raise TypeError("must be initialized with SymbolVar or SharedND")
        self.requires_grad = requires_grad

    def _as_tensor(self, obj):
        r"""Convert the data into a ``Tensor``. If the data is already a Tensor
        with the same dtype and device, no copy will be performed. Otherwise a
        new Tensor will be returned with computational graph retained.
        """
        if isinstance(obj, Tensor):
            return obj
        if isinstance(obj, mgb.SymbolVar):
            return Tensor(obj)
        if isinstance(obj, mgb.SharedScalar):
            return Tensor(obj._as_sym_var(self._comp_graph, self._comp_node))
        return tensor(data=obj, device=self.device)

    def numpy(self):
        r"""Return the tensor value in numpy.ndarray format.
        """
        if self.__val is not None:
            assert self.__sym is None
            return self.__val.get_value()
        if self.__sym is None:
            raise ValueError("uninitialized")
        if self.__sym.eager_val is not None:
            return self.__sym.eager_val.get_value()
        return self.__sym.inferred_value

    def item(self):
        return self.numpy().item()

    def _attach(self, comp_graph, *, volatile=True):
        sym = self.__sym_override or self.__sym
        if sym:
            if sym.owner_graph != comp_graph:
                raise RuntimeError("internal error")
            return sym
        if self.__val:
            return self.__val.symvar(comp_graph, volatile=volatile)
        else:
            raise ValueError("uninitialized")

    @property
    def _symvar(self):
        if self.__sym_override:
            return self.__sym_override
        if self.__sym:
            assert not self.__val
            return self.__sym
        if not self.__val:
            raise ValueError("uninitialized")
        return self._attach(get_default_graph())

    def __mgb_symvar__(self, comp_graph=None, **_):
        if self.__sym_override:
            return self.__sym_override
        if self.__val and comp_graph:
            return self._attach(comp_graph)
        return self._symvar  # read by mgb.opr

    def _override_symvar_during_trace(self, trace, symvar):
        assert self.__val and not self.__sym
        assert trace is type(trace)._active_instance
        deleters = trace._user_cache.setdefault(Tensor, set())
        self_ref = weakref.ref(self)
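        # Added note: restore() below clears the override; it is held by a
        # Guard in the trace's user cache, so the override is undone when that
        # cache is dropped. The weakref avoids extending self's lifetime.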

        def restore():
            self = self_ref()
            if self is not None:
                self.__sym_override = None

        deleters.add(Guard(restore))
        self.__sym_override = symvar

    @property
    def dtype(self):
        r"""Return the data type of the tensor.
        """
        if self.__val is not None:
            return self.__val.dtype
        return self._symvar.dtype

    @property
    def _comp_node(self):
        if self.__val is not None:
            return self.__val.comp_node
        return self._symvar.comp_node

    device = _comp_node

    @property
    def _comp_graph(self):
        if self.__sym is not None:
            return self.__sym.owner_graph
        return None

    @property
    def shape(self):
        r"""Return an int tuple that is the shape/layout of the tensor.
        Could be invalid in static graph mode.
        """
        from ..jit import trace

        if trace._active_instance:  # pylint: disable=protected-access
            # NOTE: this is a hack
            shape = mgb.opr.get_var_shape(self._symvar)
            return tuple(Tensor(shape[i]) for i in range(self.ndim))
        return self._symvar.imm_shape

    def set_value(self, value, *, sync=True, inplace=False, share=False):
        r"""Set the value of the tensor.
        """
        if not self.__val:
            raise ValueError("not detached")
        if isinstance(value, Tensor):
            value = value.__val or value.__sym.eager_val
        self.__val.set_value(value, sync=sync, inplace=inplace, share=share)

    def fill(self, value):
        r"""Fill the tensor with the specified value.
        """
        self.set_value(np.full(self.shape, value, dtype=self.dtype))
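
    # Illustrative usage of fill (a sketch, not part of the original file;
    # assumes an eager default graph):
    #
    #     t = tensor(np.zeros((2, 2), dtype=np.int32))
    #     t.fill(3)  # t.numpy() is now [[3 3], [3 3]]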

    def reset_zero(self):
        r"""Reset the tensor and fill it with zeros.
        """
        if not self.__val:
            raise ValueError("not detached")
        self.__val.reset_zero()

    def to(self, device):
        r"""Perform device conversion and return a Tensor located on the specified device.
        """
        return wrap_io_tensor(mgb.opr.copy)(self, comp_node=device)
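
    # Illustrative usage of to() (a sketch, not part of the original file;
    # the device string "cpu0" is an assumption):
    #
    #     t_cpu = t.to("cpu0")  # copies t onto the first CPU device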

    # https://docs.python.org/3/reference/datamodel.html#object.__hash__
    # > If a class does not define an __eq__() method it should not define a
    # > __hash__() operation either
    __hash__ = None  # type: ignore[assignment]

    def __eq__(self, rhs):
        rhs = self._as_tensor(rhs)
        return Tensor(self._symvar._binary_opr("EQ", rhs._symvar))

    def __ne__(self, rhs):
        return 1 - self.__eq__(rhs)
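
    # Added note: __eq__ is elementwise and returns a Tensor of 0/1 values
    # rather than a Python bool (which is why __hash__ is set to None above);
    # __ne__ is derived from it as 1 - (self == rhs).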

    def __len__(self):
        if self._symvar.eager_val is not None:
            return self._symvar.eager_val.shape[0]
        raise TypeError(
            "__len__ and __iter__ are not available for tensors on a non-eager graph"
        )

    __add__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__add__)
    __radd__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__radd__)
    __sub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__sub__)
    __rsub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rsub__)
    __mul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mul__)
    __rmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmul__)
    __matmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__matmul__)
    __rmatmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmatmul__)
    __lshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lshift__)
    __rshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rshift__)
    __truediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__truediv__)
    __rtruediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rtruediv__)
    __floordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__floordiv__)
    __rfloordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rfloordiv__)
    __mod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mod__)
    __rmod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmod__)
    __pow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__pow__)
    __rpow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rpow__)
    __lt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lt__)
    __gt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__gt__)
    __le__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__le__)
    __ge__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__ge__)
    __neg__ = wrap_io_tensor(mgb.SymbolVar.__neg__)

    sum = wrap_io_tensor(mgb.SymbolVar.sum)
    """
    Return the sum of all elements in the tensor.
    """

    max = wrap_io_tensor(mgb.SymbolVar.max)
    """
    Return the maximum value of the tensor.
    """

    min = wrap_io_tensor(mgb.SymbolVar.min)
    """
    Return the minimum value of the tensor.
    """

    prod = wrap_io_tensor(mgb.SymbolVar.prod)
    """
    Return the product of all elements in the tensor.
    """

    mean = wrap_io_tensor(mgb.SymbolVar.mean)
    """
    Return the mean value of the tensor.
    """

    dimshuffle = wrap_io_tensor(mgb.SymbolVar.dimshuffle)
    """
    See more details in :func:`~.functional.tensor.dimshuffle`.
    """

    astype = wrap_io_tensor(mgb.SymbolVar.astype)
    """
    Cast the tensor to a specified type.
    """

    def reshape(self, *target_shape):
        r"""Return a tensor with the given target shape.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            inp = tensor(np.arange(1, 17, dtype=np.int32).reshape(4, 4))
            out = tensor(np.arange(100, 116, dtype=np.int32).reshape(1, 16))
            out = out.reshape(inp.shape)
            print(out.numpy())

        .. testoutput::

            [[100 101 102 103]
             [104 105 106 107]
             [108 109 110 111]
             [112 113 114 115]]
        """
        if isinstance(target_shape[0], tuple):
            if len(target_shape) > 1:
                raise ValueError("Only single tuple is accepted in reshape")
            target_shape = target_shape[0]
        target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
        return Tensor(mgb.SymbolVar.reshape(self._symvar, *target_shape))

    def broadcast(self, *target_shape):
        r"""Return a tensor broadcast from the current tensor to the given target shape.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            data = tensor(np.arange(100, 104, dtype=np.int32).reshape(1, 4))
            data = data.broadcast((4, 4))
            print(data.numpy())

        .. testoutput::

            [[100 101 102 103]
             [100 101 102 103]
             [100 101 102 103]
             [100 101 102 103]]
        """
        if isinstance(target_shape[0], tuple):
            if len(target_shape) > 1:
                raise ValueError("Only single tuple is accepted in broadcast")
            target_shape = target_shape[0]
        target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
        return Tensor(mgb.SymbolVar.broadcast(self._symvar, *target_shape))

    # Prefer operators on Tensor instead of converting to numpy
    __array_priority__ = 1000

    # mgb indexing family
    def __getitem__(self, idx):
        return wrap_io_tensor(self._symvar.__getitem__)(wrap_idx(idx))

    def set_subtensor(self, val):
        return MGBIndexWrapper(self, mgb.opr.set_subtensor, val)

    def incr_subtensor(self, val):
        return MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)

    @property
    def ai(self):
        return MGBIndexWrapper(self, mgb.opr.advanced_indexing)

    def set_ai(self, val):
        return MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)

    def incr_ai(self, val):
        return MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)

    @property
    def mi(self):
        return MGBIndexWrapper(self, mgb.opr.mesh_indexing)

    def set_mi(self, val):
        return MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)

    def incr_mi(self, val):
        return MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)

    @property
    def batched_mi(self):
        return MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)

    def batched_set_mi(self, val):
        return MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)

    def batched_incr_mi(self, val):
        return MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)
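
    # Illustrative use of the indexing family (a sketch, not part of the
    # original file): each method returns an MGBIndexWrapper, so the index is
    # supplied with subscript syntax and a new Tensor is returned functionally:
    #
    #     y = x.set_subtensor(v)[1:3]   # y equals x with x[1:3] replaced by v
    #     z = x.incr_subtensor(v)[1:3]  # z equals x with v added into x[1:3]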

    def __array__(self, dtype=None):
        if dtype is None:
            return self.numpy()
        else:
            return self.numpy().astype(dtype, copy=False)

    def __int__(self):
        return int(self.item())

    def __index__(self):
        return int(self.item())

    def __round__(self, ndigits=0):
        if ndigits != 0:
            raise ValueError("ndigits must be 0 for Tensor.round")
        return Tensor(mgb.opr.elemwise([self._symvar], mode="ROUND"))

    round = __round__

    def sqrt(self):
        r"""Return a tensor whose elements are the square roots of the
        original values.
        """
        return Tensor(mgb.opr.sqrt(self._symvar))

    def shapeof(self, axis=None):
        r"""Return a Tensor that represents the shape of this tensor.
        """
        return Tensor(mgb.opr.get_var_shape(self._symvar, axis=axis))

    @property
    def ndim(self):
        r"""Return the number of dimensions of the tensor.
        """
        return len(self._symvar.imm_shape)

    def __repr__(self):
        piece = "Tensor("
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        if self._comp_node.locator_logical != ("XPU", -1, 0):
            piece += ", device={}".format(self.device)
        piece += ")"
        return piece

    def __bool__(self):
        raise RuntimeError(
            "Tensor objects should not be converted to bool or used in an if "
            "statement. Use .numpy(), int() or float() if you want to use its "
            "value in an if statement; be aware that this may lead to "
            "incorrect results in non-eager mode."
        )

    def __getstate__(self):
        assert (self.__val is not None) and (self.__sym is None)
        metadata = {"requires_grad": self.requires_grad}
        state = {
            "data": self.numpy(),
            "device": self.device,
            "dtype": self.dtype,
            "metadata": metadata,
        }
        return state

    def __setstate__(self, state):
        data = state.pop("data")
        device = state.pop("device")
        dtype = state.pop("dtype")
        metadata = state.pop("metadata", {})
        requires_grad = metadata.pop("requires_grad", None)
        snd = mgb.make_shared(device, value=data, dtype=dtype)
        self._reset(snd, requires_grad=requires_grad)
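
    # Added note: __getstate__/__setstate__ snapshot (data, device, dtype,
    # metadata) and rebuild the underlying SharedND, so a detached tensor
    # survives a pickle round-trip, e.g. t2 = pickle.loads(pickle.dumps(t)).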


def tensor(
    data: Union[list, np.ndarray] = None,
    *,
    dtype: str = None,
    device: mgb.CompNode = None,
    requires_grad: bool = None
):
    r"""A helper function to create a :class:`~.Tensor` using existing data.

    :param data: an existing data array; must be a Python list, a NumPy array or None.
    :param dtype: target Tensor data type, one of ``("uint8", "int8", "int16", "int32", "float32", "float16")``.
    :param device: target device on which to store the Tensor.
    :param requires_grad: whether its gradient will be calculated during :meth:`~.Optimizer.backward`
    """
    supported_dtypes = ("uint8", "int8", "int16", "int32", "float32", "float16")

    if isinstance(data, Tensor):
        raise NotImplementedError
    if dtype is not None and np.dtype(dtype).name not in supported_dtypes:
        raise TypeError("unsupported dtype {}".format(dtype))
    if data is not None:
        if not isinstance(data, np.ndarray):
            data = np.array(data, dtype=dtype)
            # In order to accept tensor([1]), automatically convert to 32-bit
            # numbers instead of numpy's default 64-bit when the input data is
            # not an ndarray.
            dtype = mgb.to_mgb_supported_dtype(data.dtype)
        if dtype is None:
            if data.dtype.name not in supported_dtypes:
                raise TypeError("unsupported dtype {}".format(data.dtype))
    device, _ = _use_default_if_none(device, None)
    shared_nd = mgb.make_shared(device, value=data, dtype=dtype)
    return Tensor(shared_nd, requires_grad=requires_grad)
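
# Illustrative usage of tensor() (a sketch, not part of the original file):
#
#     t = tensor([1, 2, 3])                         # list input becomes int32
#     u = tensor(np.ones((2, 2)), dtype="float32")  # explicit dtype cast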


class Dict(collections.abc.MutableMapping):
    def __init__(self, *args, key=None, **kwargs):
        self.data = {}
        if key:
            self.keyfn = key
        for i in args:
            self.update(i)
        self.update(**kwargs)

    @staticmethod
    def keyfn(key):  # pylint: disable=method-hidden
        return key

    def __getitem__(self, key):
        _, v = self.data[self.keyfn(key)]
        return v

    def __setitem__(self, key, value):
        self.data[self.keyfn(key)] = key, value

    def __delitem__(self, key):
        del self.data[self.keyfn(key)]

    def __iter__(self):
        for _, (k, _) in self.data.items():
            yield k

    def __len__(self):
        return len(self.data)
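
# Illustrative use of Dict with a key function (a sketch, not part of the
# original file):
#
#     d = Dict(key=lambda k: k.lower())
#     d["Foo"] = 1
#     assert d["FOO"] == 1        # lookups are canonicalized by keyfn
#     assert list(d) == ["Foo"]   # iteration yields the original keys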


class TensorDict(Dict):  # pylint: disable=too-many-ancestors
    class keyfn:
        def __new__(cls, x: Tensor):
            if not isinstance(x, Tensor):
                return x
            return super().__new__(cls)

        def __init__(self, x: Tensor):
            self._data = x  # do not save id directly to make pickle work

        def __hash__(self):
            return id(self._data)

        def __eq__(self, other):
            # pylint: disable=undefined-variable
            return isinstance(other, __class__) and id(self._data) == id(other._data)

    def __init__(self, *args):
        super().__init__(*args)
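
# Added note: TensorDict keys tensors by object identity. This matters because
# Tensor.__eq__ is elementwise and Tensor.__hash__ is None, so tensors cannot
# be keys of a plain dict. Illustrative (a sketch, not part of the original):
#
#     td = TensorDict()
#     a, b = tensor([1]), tensor([1])
#     td[a], td[b] = "a", "b"  # two entries: equal values, distinct objects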

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has GPU hardware and a properly installed driver. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
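For example, placing a tensor on a GPU (a minimal sketch; it assumes a CUDA-capable machine where the device string "gpu0" is valid):

    from megengine import tensor
    t = tensor([1.0, 2.0], device="gpu0")  # stored on the first GPU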