You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import collections.abc
import functools
import itertools
from typing import Union

import numpy as np

import megengine._internal as mgb

from .graph import _use_default_if_none, get_default_graph
  16. def wrap_io_tensor(func):
  17. r"""A wrapper to make ``func`` compatible with functions in ``_internal.opr``.
  18. """
  19. @functools.wraps(func)
  20. def wrapper(*args, **kwargs):
  21. comp_graph = None
  22. for i in itertools.chain(args, kwargs.values()):
  23. if isinstance(i, Tensor) and i._comp_graph:
  24. comp_graph = i._comp_graph
  25. break
  26. else:
  27. comp_graph = get_default_graph()
  28. new_args = (
  29. arg._attach(comp_graph) if isinstance(arg, Tensor) else arg for arg in args
  30. )
  31. new_kwargs = {
  32. k: v._attach(comp_graph) if isinstance(v, Tensor) else v
  33. for k, v in kwargs.items()
  34. }
  35. ret = func(*new_args, **new_kwargs)
  36. if isinstance(ret, mgb.SymbolVar):
  37. ret = Tensor(ret)
  38. elif isinstance(ret, list):
  39. ret = [Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret]
  40. elif isinstance(ret, tuple):
  41. ret = tuple(Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret)
  42. return ret
  43. return wrapper
  44. def _wrap_symbolvar_binary_op(f):
  45. @functools.wraps(f)
  46. def wrapped(self, other):
  47. comp_graph = (
  48. isinstance(other, Tensor)
  49. and other._comp_graph
  50. or self._comp_graph
  51. or get_default_graph()
  52. )
  53. if isinstance(other, Tensor):
  54. other = other._attach(comp_graph)
  55. return Tensor(f(self._attach(comp_graph), other))
  56. return wrapped
  57. def wrap_slice(inp):
  58. start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start
  59. stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop
  60. step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step
  61. return slice(start, stop, step)
  62. def wrap_idx(idx):
  63. if not isinstance(idx, tuple):
  64. idx = (idx,)
  65. idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx)
  66. idx = tuple(wrap_slice(i) if isinstance(i, slice) else i for i in idx)
  67. return idx
  68. class MGBIndexWrapper:
  69. def __init__(self, dest, mgb_index, val=None):
  70. self.dest = dest
  71. self.val = val
  72. self.mgb_index = mgb_index
  73. def __getitem__(self, idx):
  74. if self.val is None:
  75. return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)(
  76. wrap_idx(idx)
  77. )
  78. else:
  79. return wrap_io_tensor(
  80. self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__
  81. )(wrap_idx(idx))
  82. class Tensor:
  83. r"""The main data container in MegEngine.
  84. Use :func:`~.tensor` to create a Tensor with existed data.
  85. """
  86. requires_grad = False
  87. grad = None
  88. def __init__(self, val=None, *, requires_grad=None):
  89. self._reset(val, requires_grad=requires_grad)
  90. def _reset(self, val=None, *, requires_grad=None):
  91. if val is None:
  92. self.__val = None
  93. self.__sym = None
  94. elif isinstance(val, mgb.SharedND):
  95. self.__val = val
  96. self.__sym = None
  97. elif isinstance(val, mgb.SymbolVar):
  98. self.__val = None
  99. self.__sym = val
  100. else:
  101. raise TypeError("must be initialized with SymbolVar or SharedND")
  102. self.requires_grad = requires_grad
  103. def _as_tensor(self, obj):
  104. r"""Convert the data into a ``Tensor``. If the data is already a Tensor
  105. with the same dtype and device, no copy will be performed. Otherwise a
  106. new Tensor will be returned with computational graph retained.
  107. """
  108. if isinstance(obj, Tensor):
  109. return obj
  110. if isinstance(obj, mgb.SymbolVar):
  111. return Tensor(obj)
  112. if isinstance(obj, mgb.SharedScalar):
  113. return Tensor(obj._as_sym_var(self._comp_graph, self._comp_node))
  114. return tensor(data=obj, device=self.device)
  115. def numpy(self):
  116. r"""Return the tensor value in numpy.ndarray format.
  117. """
  118. if self.__val is not None:
  119. assert self.__sym is None
  120. return self.__val.get_value()
  121. if self.__sym is None:
  122. raise ValueError("uninitialized")
  123. if self.__sym.eager_val is not None:
  124. return self.__sym.eager_val.get_value()
  125. return self.__sym.inferred_value
  126. def item(self):
  127. return self.numpy().item()
  128. def _attach(self, comp_graph, *, volatile=True):
  129. if self.__val:
  130. return self.__val.symvar(comp_graph, volatile=volatile)
  131. if self.__sym:
  132. if self.__sym.owner_graph != comp_graph:
  133. raise RuntimeError("internal error")
  134. return self.__sym
  135. else:
  136. raise ValueError("uninitialized")
  137. @property
  138. def _symvar(self):
  139. if self.__sym:
  140. assert not self.__val
  141. return self.__sym
  142. if not self.__val:
  143. raise ValueError("uninitialized")
  144. return self._attach(get_default_graph())
  145. def __mgb_symvar__(self, comp_graph=None, **_):
  146. if self.__val and comp_graph:
  147. return self._attach(comp_graph)
  148. return self._symvar # read by mgb.opr
  149. @property
  150. def dtype(self):
  151. r"""Return the data type of the tensor.
  152. """
  153. if self.__val is not None:
  154. return self.__val.dtype
  155. return self._symvar.dtype
  156. @property
  157. def _comp_node(self):
  158. if self.__val is not None:
  159. return self.__val.comp_node
  160. return self._symvar.comp_node
  161. device = _comp_node
  162. @property
  163. def _comp_graph(self):
  164. if self.__sym is not None:
  165. return self.__sym.owner_graph
  166. return None
  167. @property
  168. def shape(self):
  169. r"""Return an int tuple that is the shape/layout of the tensor.
  170. Could be invalid in static graph mode.
  171. """
  172. from ..jit import trace
  173. if trace._active_instance: # pylint: disable=protected-access
  174. # NOTE: this is an hack
  175. shape = mgb.opr.get_var_shape(self._symvar)
  176. return tuple(Tensor(shape[i]) for i in range(self.ndim))
  177. return self._symvar.imm_shape
  178. def set_value(self, value, *, sync=True, inplace=False, share=False):
  179. r"""Set value to the tensor.
  180. """
  181. if not self.__val:
  182. raise ValueError("not detached")
  183. if isinstance(value, Tensor):
  184. value = value.__val or value.__sym.eager_val
  185. self.__val.set_value(value, sync=sync, inplace=inplace, share=share)
  186. def fill(self, value):
  187. r"""Fills the tensor with the specified value.
  188. """
  189. self.set_value(np.full(self.shape, value, dtype=self.dtype))
  190. def reset_zero(self):
  191. r"""Reset the tensor and fills with zeros.
  192. """
  193. if not self.__val:
  194. raise ValueError("not detached")
  195. self.__val.reset_zero()
  196. def to(self, device):
  197. r"""Performs Tensor device conversion, returns Tensor with the specified device.
  198. """
  199. return wrap_io_tensor(mgb.opr.copy)(self, comp_node=device)
  200. # https://docs.python.org/3/reference/datamodel.html#object.__hash__
  201. # > If a class does not define an __eq__() method it should not define a
  202. # > __hash__() operation either
  203. __hash__ = None # type: ignore[assignment]
  204. def __eq__(self, rhs):
  205. rhs = self._as_tensor(rhs)
  206. return Tensor(self._symvar._binary_opr("EQ", rhs._symvar))
  207. def __ne__(self, rhs):
  208. return 1 - self.__eq__(rhs)
  209. def __len__(self):
  210. if self._symvar.eager_val is not None:
  211. return self._symvar.eager_val.shape[0]
  212. raise TypeError(
  213. "__len__ and __iter__ is not available for tensors on non eager graph."
  214. )
  215. __add__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__add__)
  216. __radd__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__radd__)
  217. __sub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__sub__)
  218. __rsub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rsub__)
  219. __mul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mul__)
  220. __rmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmul__)
  221. __matmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__matmul__)
  222. __rmatmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmatmul__)
  223. __lshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lshift__)
  224. __rshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rshift__)
  225. __truediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__truediv__)
  226. __rtruediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rtruediv__)
  227. __floordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__floordiv__)
  228. __rfloordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rfloordiv__)
  229. __mod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mod__)
  230. __rmod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmod__)
  231. __pow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__pow__)
  232. __rpow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rpow__)
  233. __lt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lt__)
  234. __gt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__gt__)
  235. __le__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__le__)
  236. __ge__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__ge__)
  237. __neg__ = wrap_io_tensor(mgb.SymbolVar.__neg__)
  238. sum = wrap_io_tensor(mgb.SymbolVar.sum)
  239. """
  240. Sum up the given tensors.
  241. """
  242. max = wrap_io_tensor(mgb.SymbolVar.max)
  243. """
  244. Return the maximum value of given tensor.
  245. """
  246. min = wrap_io_tensor(mgb.SymbolVar.min)
  247. """
  248. Return the minimum value of given tensor.
  249. """
  250. prod = wrap_io_tensor(mgb.SymbolVar.prod)
  251. """
  252. Return the product value of the given tensor.
  253. """
  254. mean = wrap_io_tensor(mgb.SymbolVar.mean)
  255. """
  256. Return the mean value of the given tensor.
  257. """
  258. dimshuffle = wrap_io_tensor(mgb.SymbolVar.dimshuffle)
  259. """
  260. See more details in :func:`~.functional.tensor.dimshuffle`.
  261. """
  262. astype = wrap_io_tensor(mgb.SymbolVar.astype)
  263. """
  264. Cast the tensor to a specified type.
  265. """
  266. def reshape(self, *target_shape):
  267. r"""Return a tensor which has given target shape
  268. Examples:
  269. .. testcode::
  270. import numpy as np
  271. from megengine import tensor
  272. inp = tensor(np.arange(1, 17, dtype=np.int32).reshape(4,4))
  273. out = tensor(np.arange(100, 116, dtype=np.int32).reshape(1,16))
  274. out = out.reshape(inp.shape)
  275. print(out.numpy())
  276. .. testoutput::
  277. [[100 101 102 103]
  278. [104 105 106 107]
  279. [108 109 110 111]
  280. [112 113 114 115]]
  281. """
  282. if isinstance(target_shape[0], tuple):
  283. if len(target_shape) > 1:
  284. raise ValueError("Only single tuple is accepted in reshape")
  285. target_shape = target_shape[0]
  286. target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
  287. return Tensor(mgb.SymbolVar.reshape(self._symvar, *target_shape))
  288. def broadcast(self, *target_shape):
  289. r"""Return a tesnor broadcasted by current tensor to given target shape
  290. Examples:
  291. .. testcode::
  292. import numpy as np
  293. from megengine import tensor
  294. data = tensor(np.arange(100, 104, dtype=np.int32).reshape(1,4))
  295. data = data.broadcast((4,4))
  296. print(data.numpy())
  297. .. testoutput::
  298. [[100 101 102 103]
  299. [100 101 102 103]
  300. [100 101 102 103]
  301. [100 101 102 103]]
  302. """
  303. if isinstance(target_shape[0], tuple):
  304. if len(target_shape) > 1:
  305. raise ValueError("Only single tuple is accepted in broadcast")
  306. target_shape = target_shape[0]
  307. target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
  308. return Tensor(mgb.SymbolVar.broadcast(self._symvar, *target_shape))
  309. # Prefer operators on Tensor instead of convert to numpy
  310. __array_priority__ = 1000
  311. # mgb indexing family
  312. def __getitem__(self, idx):
  313. return wrap_io_tensor(self._symvar.__getitem__)(wrap_idx(idx))
  314. def set_subtensor(self, val):
  315. return MGBIndexWrapper(self, mgb.opr.set_subtensor, val)
  316. def incr_subtensor(self, val):
  317. return MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)
  318. @property
  319. def ai(self):
  320. return MGBIndexWrapper(self, mgb.opr.advanced_indexing)
  321. def set_ai(self, val):
  322. return MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)
  323. def incr_ai(self, val):
  324. return MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)
  325. @property
  326. def mi(self):
  327. return MGBIndexWrapper(self, mgb.opr.mesh_indexing)
  328. def set_mi(self, val):
  329. return MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)
  330. def incr_mi(self, val):
  331. return MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)
  332. @property
  333. def batched_mi(self):
  334. return MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)
  335. def batched_set_mi(self, val):
  336. return MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)
  337. def batched_incr_mi(self, val):
  338. return MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)
  339. def __array__(self, dtype=None):
  340. if dtype is None:
  341. return self.numpy()
  342. else:
  343. return self.numpy().astype(dtype, copy=False)
  344. def __int__(self):
  345. return int(self.item())
  346. def __index__(self):
  347. return int(self.item())
  348. def __round__(self, ndigits=0):
  349. if ndigits != 0:
  350. raise ValueError("ndigits must be 0 for Tensor.round")
  351. return Tensor(mgb.opr.elemwise([self._symvar], mode="ROUND"))
  352. round = __round__
  353. def sqrt(self):
  354. r"""Return a tensor that each element is the square root of its
  355. original value.
  356. """
  357. return Tensor(mgb.opr.sqrt(self._symvar))
  358. def shapeof(self, axis=None):
  359. r"""Return a Tensor that represent the shape of the tensor.
  360. """
  361. return Tensor(mgb.opr.get_var_shape(self._symvar, axis=axis))
  362. @property
  363. def ndim(self):
  364. r"""Return the number of dimensions of the tensor.
  365. """
  366. return len(self._symvar.imm_shape)
  367. def __repr__(self):
  368. piece = "Tensor("
  369. with np.printoptions(precision=4, suppress=True):
  370. piece += "{}".format(str(self.numpy()))
  371. if self.dtype != np.float32:
  372. piece += ", dtype={}".format(np.dtype(self.dtype).name)
  373. if self._comp_node.locator_logical != ("XPU", -1, 0):
  374. piece += ", device={}".format(self.device)
  375. piece += ")"
  376. return piece
  377. def __bool__(self):
  378. raise RuntimeError(
  379. "Tensor object should not be converted to bool or used in a if statement. Use .numpy(), int() or float() if you want to use its value in if statement, be aware that this may lead to incorrect result in non-eager mode."
  380. )
  381. def __getstate__(self):
  382. assert (self.__val is not None) and (self.__sym is None)
  383. metadata = {"requires_grad": self.requires_grad}
  384. state = {
  385. "data": self.numpy(),
  386. "device": self.device,
  387. "dtype": self.dtype,
  388. "metadata": metadata,
  389. }
  390. return state
  391. def __setstate__(self, state):
  392. data = state.pop("data")
  393. device = state.pop("device")
  394. dtype = state.pop("dtype")
  395. metadata = state.pop("metadata", {})
  396. requires_grad = metadata.pop("requires_grad", None)
  397. snd = mgb.make_shared(device, value=data, dtype=dtype)
  398. self._reset(snd, requires_grad=requires_grad)
  399. def tensor(
  400. data: Union[list, np.ndarray] = None,
  401. *,
  402. dtype: str = None,
  403. device: mgb.CompNode = None,
  404. requires_grad: bool = None
  405. ):
  406. r"""A helper function to create a :class:`~.Tensor` using existing data.
  407. :param data: an existing data array, must be Python list, NumPy array or None.
  408. :param dtype: target Tensor data type, one of ``("uint8", "int8", "int16", "int32", "float32", "float16")``.
  409. :param device: target device for Tensor storing.
  410. :param requires_grad: whether its gradiant will be calculated during :meth:`~.Optimizer.backward`
  411. """
  412. supported_dtypes = ("uint8", "int8", "int16", "int32", "float32", "float16")
  413. if isinstance(data, Tensor):
  414. raise NotImplementedError
  415. if dtype is not None and np.dtype(dtype).name not in supported_dtypes:
  416. raise TypeError("unsupported dtype {}".format(dtype))
  417. if data is not None:
  418. if not isinstance(data, np.ndarray):
  419. data = np.array(data, dtype=dtype)
  420. # In order to accept tensor([1]),
  421. # Automaticlly convert to 32-bit number instead of numpy's default 64-bit when input data is not nparray.
  422. dtype = mgb.to_mgb_supported_dtype(data.dtype)
  423. if dtype is None:
  424. if data.dtype.name not in supported_dtypes:
  425. raise TypeError("unsupported dtype {}".format(data.dtype))
  426. device, _ = _use_default_if_none(device, None)
  427. shared_nd = mgb.make_shared(device, value=data, dtype=dtype)
  428. return Tensor(shared_nd, requires_grad=requires_grad)
  429. class Dict(collections.MutableMapping):
  430. def __init__(self, *args, key=None, **kwargs):
  431. self.data = {}
  432. if key:
  433. self.keyfn = key
  434. for i in args:
  435. self.update(i)
  436. self.update(**kwargs)
  437. @staticmethod
  438. def keyfn(key): # pylint: disable=method-hidden
  439. return key
  440. def __getitem__(self, key):
  441. _, v = self.data[self.keyfn(key)]
  442. return v
  443. def __setitem__(self, key, value):
  444. self.data[self.keyfn(key)] = key, value
  445. def __delitem__(self, key):
  446. del self.data[self.keyfn(key)]
  447. def __iter__(self):
  448. for _, (k, _) in self.data.items():
  449. yield k
  450. def __len__(self):
  451. return len(self.data)
  452. class TensorDict(Dict): # pylint: disable=too-many-ancestors
  453. class keyfn:
  454. def __new__(cls, x: Tensor):
  455. if not isinstance(x, Tensor):
  456. return x
  457. return super().__new__(cls)
  458. def __init__(self, x: Tensor):
  459. self._data = x # do not save id directly to make pickle work
  460. def __hash__(self):
  461. return id(self._data)
  462. def __eq__(self, other):
  463. # pylint: disable=undefined-variable
  464. return isinstance(other, __class__) and id(self._data) == id(other._data)
  465. def __init__(self, *args):
  466. super().__init__(*args)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。