You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. from typing import Iterable, List, Optional, Union
  11. import numpy as np
  12. import megengine._internal as mgb
  13. from megengine._internal import CompGraph, CompNode
  14. from ..core import zeros
  15. from ..core.graph import _use_default_if_none
  16. from ..core.tensor import Tensor, wrap_io_tensor
  17. from .utils import _decide_comp_node_and_comp_graph
  18. @wrap_io_tensor
  19. def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
  20. """
  21. Broadcast a tensor to ``shape``
  22. :param inp: The input tensor
  23. :param shape: The target shape
  24. :return: The output tensor
  25. Examples:
  26. .. testcode::
  27. import numpy as np
  28. from megengine import tensor
  29. import megengine.functional as F
  30. data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  31. out = F.broadcast_to(data, (4, 2, 3))
  32. print(out.numpy())
  33. Outputs:
  34. .. testoutput::
  35. [[[0. 1. 2.]
  36. [3. 4. 5.]]
  37. [[0. 1. 2.]
  38. [3. 4. 5.]]
  39. [[0. 1. 2.]
  40. [3. 4. 5.]]
  41. [[0. 1. 2.]
  42. [3. 4. 5.]]]
  43. """
  44. if isinstance(shape, int):
  45. shape = (shape,)
  46. return mgb.opr.broadcast(inp, shape)
  47. def _get_idx(index, axis):
  48. index_dims = len(index.imm_shape)
  49. idx = []
  50. comp_node, comp_graph = _decide_comp_node_and_comp_graph(index)
  51. for i in range(index_dims):
  52. if i != axis:
  53. shape = [1] * index_dims
  54. shape[i] = index.axis_shape(i)
  55. arange = mgb.opr.linspace(
  56. 0,
  57. index.axis_shape(i) - 1,
  58. index.axis_shape(i),
  59. comp_node=comp_node,
  60. comp_graph=comp_graph,
  61. )
  62. arange = (
  63. arange.reshape(*shape)
  64. .broadcast(index.shape)
  65. .reshape(-1)
  66. .astype(np.int32)
  67. )
  68. idx.append(arange)
  69. else:
  70. idx.append(index.reshape(-1))
  71. return tuple(idx)
  72. @wrap_io_tensor
  73. def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
  74. r"""
  75. Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`.
  76. For a 3-D tensor, the output is specified by::
  77. out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
  78. out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
  79. out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2
  80. if :attr:`inp` is an n-dimensional tensor with size
  81. :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
  82. then :attr:`index` must be an n-dimensional tensor with size
  83. :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
  84. output will have the same size as :attr:`index`.
  85. :param inp: the source tensor
  86. :param axis: the axis along which to index
  87. :param index: the indices of elements to gather
  88. Examples:
  89. .. testcode::
  90. import megengine.functional as F
  91. from megengine.core import tensor
  92. inp = tensor([
  93. [1,2], [3,4], [5,6],
  94. ])
  95. index = tensor([[0,2], [1,0]])
  96. oup = F.gather(inp, 0, index)
  97. print(oup.numpy())
  98. Outputs:
  99. .. testoutput::
  100. [[1 6]
  101. [3 2]]
  102. """
  103. input_shape = inp.imm_shape
  104. index_shape = index.imm_shape
  105. input_dims = len(input_shape)
  106. index_dims = len(index_shape)
  107. if input_dims != index_dims:
  108. raise ValueError(
  109. "The index tensor must have same dimensions as input tensor, "
  110. "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
  111. )
  112. if axis < 0 or axis >= input_dims:
  113. raise ValueError(
  114. "Index axis {} is output of bounds, should in range [0 {})".format(
  115. axis, input_dims
  116. )
  117. )
  118. for i in range(input_dims):
  119. if i != axis and input_shape[i] != index_shape[i]:
  120. raise ValueError(
  121. "The input {} and index {} must have the same size apart from axis {}".format(
  122. input_shape, index_shape, axis
  123. )
  124. )
  125. idx = _get_idx(index, axis)
  126. return mgb.opr.advanced_indexing(inp)[idx].reshape(
  127. index.shape
  128. ) # pylint: disable=no-member
  129. @wrap_io_tensor
  130. def concat(
  131. inps: Iterable[Tensor],
  132. axis: int = 0,
  133. device: Optional[CompNode] = None,
  134. comp_graph: Optional[CompGraph] = None,
  135. ) -> Tensor:
  136. r"""
  137. Concat some tensors
  138. :param inps: Input tensors to concat
  139. :param axis: the dimension over which the tensors are concatenated,
  140. default to 0
  141. :param device: The comp node output on, default to None
  142. :param comp_graph: The graph in which output is, default to None
  143. :return: The output tensor
  144. Examples:
  145. .. testcode::
  146. import numpy as np
  147. from megengine import tensor
  148. import megengine.functional as F
  149. data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
  150. data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
  151. out = F.concat([data1, data2])
  152. print(out.numpy())
  153. Outputs:
  154. .. testoutput::
  155. [[ 0. 1. 2.]
  156. [ 3. 4. 5.]
  157. [ 6. 7. 8.]
  158. [ 9. 10. 11.]]
  159. """
  160. # Output buffer not supported
  161. return mgb.opr.concat(
  162. *list(inps), axis=axis, comp_node=device, comp_graph=comp_graph
  163. )
  164. @wrap_io_tensor
  165. def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
  166. r"""
  167. Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor.
  168. For each value in :attr:`source`, its output index is specified by its index
  169. in :attr:`source` for ``axis != dimension`` and by the corresponding value in
  170. :attr:`index` for ``axis = dimension``.
  171. For a 3-D tensor, :attr:`inp` is updated as::
  172. inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0
  173. inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1
  174. inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2
  175. :attr:`inp`, :attr:`index` and :attr:`source` should have same number of dimensions.
  176. It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
  177. for all dimensions ``d``.
  178. Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.
  179. :param inp: the inp tensor which to be scattered
  180. :param axis: the axis along which to index
  181. :param index: the indices of elements to scatter
  182. :param source: the source element(s) to scatter
  183. Examples:
  184. .. testcode::
  185. import numpy as np
  186. import megengine.functional as F
  187. from megengine.core import tensor
  188. inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
  189. source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
  190. index = tensor([[0,2,0,2,1],[2,0,0,1,2]])
  191. oup = F.scatter(inp, 0, index,source)
  192. print(oup.numpy())
  193. Outputs:
  194. .. testoutput::
  195. [[0.9935 0.0718 0.5939 0. 0. ]
  196. [0. 0. 0. 0.357 0.4396]
  197. [0.7723 0.9465 0. 0.8926 0.4576]]
  198. """
  199. input_shape = inp.imm_shape
  200. index_shape = index.imm_shape
  201. source_shape = source.imm_shape
  202. input_dims = len(input_shape)
  203. index_dims = len(index_shape)
  204. source_dims = len(source_shape)
  205. if input_dims != index_dims or input_dims != source_dims:
  206. raise ValueError("The input, source and index tensor must have same dimensions")
  207. if axis < 0 or axis >= input_dims:
  208. raise ValueError(
  209. "Index axis {} is output of bounds, should in range [0 {})".format(
  210. axis, input_dims
  211. )
  212. )
  213. for i in range(source_dims):
  214. if source_shape[i] > input_shape[i]:
  215. raise ValueError(
  216. "The each shape size for source {} must be less than or equal to input {} ".format(
  217. source_shape, input_shape
  218. )
  219. )
  220. for i in range(index_dims):
  221. if index_shape[i] != source_shape[i]:
  222. raise ValueError(
  223. "The each shape size for index {} must be equal to source {} ".format(
  224. index_shape, source_shape
  225. )
  226. )
  227. for i in range(index_dims):
  228. if i != axis and index_shape[i] > input_shape[i]:
  229. raise ValueError(
  230. "The index {} must be less than or equal to input {} size apart from axis {}".format(
  231. index_shape, input_shape, axis
  232. )
  233. )
  234. idx = _get_idx(index, axis)
  235. return mgb.opr.set_advanced_indexing(inp, source.flatten())[idx]
  236. @wrap_io_tensor
  237. def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
  238. r"""
  239. Select elements either from Tensor x or Tensor y, according to mask.
  240. .. math::
  241. \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i
  242. :param mask: a mask used for choosing x or y
  243. :param x: the first choice
  244. :param y: the second choice
  245. Examples:
  246. .. testcode::
  247. from megengine import tensor
  248. import megengine.functional as F
  249. mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
  250. x = tensor(np.array([[1, np.inf], [np.nan, 4]],
  251. dtype=np.float32))
  252. y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
  253. out = F.where(mask, x, y)
  254. print(out.numpy())
  255. Outputs:
  256. .. testoutput::
  257. [[1. 6.]
  258. [7. 4.]]
  259. """
  260. v0, index0 = mgb.opr.cond_take(
  261. x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=1
  262. )
  263. v1, index1 = mgb.opr.cond_take(
  264. y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0
  265. )
  266. out = x.flatten()
  267. out = mgb.opr.set_advanced_indexing(out, v0)[index0]
  268. out = mgb.opr.set_advanced_indexing(out, v1)[index1]
  269. out = out.reshape(x.shape)
  270. return out
  271. def shapeof(x: Tensor, axis=None):
  272. r"""
  273. The shape of input tensor.
  274. """
  275. return x.shapeof(axis=axis)
  276. @wrap_io_tensor
  277. def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
  278. r"""
  279. Swap shapes and strides according to given pattern
  280. :param inp: Input tensor
  281. :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:
  282. * (``'x'``) -> make a 0d (scalar) into a 1d vector
  283. * (0, 1) -> identity for 2d vectors
  284. * (1, 0) -> inverts the first and second dimensions
  285. * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
  286. * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
  287. * (2, 0, 1) -> AxBxC to CxAxB
  288. * (0, ``'x'``, 1) -> AxB to Ax1xB
  289. * (1, ``'x'``, 0) -> AxB to Bx1xA
  290. * (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A)
  291. :return: The output tensor
  292. Examples:
  293. .. testcode::
  294. import numpy as np
  295. from megengine import tensor
  296. import megengine.functional as F
  297. x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
  298. out = F.dimshuffle(x, (1, 0))
  299. print(out.numpy())
  300. Outputs:
  301. .. testoutput::
  302. [[1 0]
  303. [1 0]]
  304. """
  305. return mgb.opr.dimshuffle(inp, pattern)
  306. @wrap_io_tensor
  307. def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
  308. r"""
  309. Reshape a tensor to given target shape; total number of logical elements must
  310. remain unchanged
  311. :param inp: Input tensor
  312. :param target_shape: target shape, the components would be concatenated to form the
  313. target shape, and it can contain an element of -1 representing unspec_axis.
  314. Examples:
  315. .. testcode::
  316. import numpy as np
  317. from megengine import tensor
  318. import megengine.functional as F
  319. x = tensor(np.arange(12, dtype=np.int32))
  320. out = F.reshape(x, (3, 2, 2))
  321. print(out.numpy())
  322. Outputs:
  323. .. testoutput::
  324. [[[ 0 1]
  325. [ 2 3]]
  326. [[ 4 5]
  327. [ 6 7]]
  328. [[ 8 9]
  329. [10 11]]]
  330. """
  331. return mgb.opr.reshape(inp, target_shape)
  332. @functools.wraps(dimshuffle)
  333. def transpose(*args, **kwargs):
  334. r"""See :func:`dimshuffle`
  335. """
  336. return dimshuffle(*args, **kwargs)
  337. @wrap_io_tensor
  338. def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
  339. r"""
  340. Add dimension(s) before given axis/axes
  341. :param inp: Input tensor
  342. :param axis: Place(s) of new axes
  343. :return: The output tensor
  344. Examples:
  345. .. testcode::
  346. import numpy as np
  347. from megengine import tensor
  348. import megengine.functional as F
  349. x = tensor([1, 2])
  350. out = F.add_axis(x, (0, 2))
  351. print(out.shape)
  352. Outputs:
  353. .. testoutput::
  354. (1, 2, 1)
  355. """
  356. return mgb.opr.add_axis(inp, axis)
  357. @wrap_io_tensor
  358. def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
  359. r"""
  360. Remove dimension(s) of shape 1
  361. :param inp: Input tensor
  362. :param axis: Place(s) of axes to be removed
  363. :return: The output tensor
  364. Examples:
  365. .. testcode::
  366. import numpy as np
  367. from megengine import tensor
  368. import megengine.functional as F
  369. x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
  370. out = F.remove_axis(x, (0, 0, 1))
  371. print(out.shape)
  372. Outputs:
  373. .. testoutput::
  374. (2,)
  375. """
  376. return mgb.opr.remove_axis(inp, axis)
  377. def linspace(
  378. start: Union[int, float, Tensor],
  379. stop: Union[int, float, Tensor],
  380. num: int = 100,
  381. dtype=np.float32,
  382. device: Optional[CompNode] = None,
  383. comp_graph: Optional[CompGraph] = None,
  384. ) -> Tensor:
  385. r"""
  386. Return equally spaced numbers over a specified interval
  387. :param start: Starting value of the squence, shoule be scalar
  388. :param stop: The last value of the squence, shoule be scalar
  389. :param num: number of values to generate
  390. :param dtype: result data type
  391. :return: The generated tensor
  392. Examples:
  393. .. testcode::
  394. import numpy as np
  395. import megengine.functional as F
  396. a = F.linspace(3,10,5)
  397. print(a.numpy())
  398. .. testoutput::
  399. [ 3. 4.75 6.5 8.25 10. ]
  400. """
  401. if dtype is not np.float32:
  402. raise ValueError("linspace is only implemented for float32")
  403. device, comp_graph = _use_default_if_none(device, comp_graph)
  404. ret = Tensor(
  405. mgb.opr.linspace(start, stop, num, comp_node=device, comp_graph=comp_graph)
  406. )
  407. return ret.astype(dtype)
  408. def zeros_like(inp: Tensor) -> Tensor:
  409. r"""
  410. Returns a zero tensor with the same shape as input tensor
  411. :param inp: input tensor
  412. Examples:
  413. .. testcode::
  414. import numpy as np
  415. from megengine import tensor
  416. import megengine.functional as F
  417. inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  418. out = F.zeros_like(inp)
  419. print(out.numpy())
  420. .. testoutput::
  421. [[0 0 0]
  422. [0 0 0]]
  423. """
  424. return zeros(inp.shapeof()).astype(inp.dtype)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)