You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. from typing import Iterable, List, Optional, Union
  11. import numpy as np
  12. import megengine._internal as mgb
  13. from megengine._internal import CompGraph, CompNode
  14. from ..core import zeros
  15. from ..core.graph import _use_default_if_none
  16. from ..core.tensor import Tensor, wrap_io_tensor
  17. from .elemwise import ceil
  18. from .utils import _decide_comp_node_and_comp_graph
  19. @wrap_io_tensor
  20. def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
  21. """
  22. Broadcast a tensor to ``shape``
  23. :param inp: The input tensor
  24. :param shape: The target shape
  25. :return: The output tensor
  26. Examples:
  27. .. testcode::
  28. import numpy as np
  29. from megengine import tensor
  30. import megengine.functional as F
  31. data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  32. out = F.broadcast_to(data, (4, 2, 3))
  33. print(out.numpy())
  34. Outputs:
  35. .. testoutput::
  36. [[[0. 1. 2.]
  37. [3. 4. 5.]]
  38. [[0. 1. 2.]
  39. [3. 4. 5.]]
  40. [[0. 1. 2.]
  41. [3. 4. 5.]]
  42. [[0. 1. 2.]
  43. [3. 4. 5.]]]
  44. """
  45. if isinstance(shape, int):
  46. shape = (shape,)
  47. return mgb.opr.broadcast(inp, shape)
  48. def _get_idx(index, axis):
  49. index_dims = len(index.imm_shape)
  50. idx = []
  51. comp_node, comp_graph = _decide_comp_node_and_comp_graph(index)
  52. for i in range(index_dims):
  53. if i != axis:
  54. shape = [1] * index_dims
  55. shape[i] = index.axis_shape(i)
  56. arange = mgb.opr.linspace(
  57. 0,
  58. index.axis_shape(i) - 1,
  59. index.axis_shape(i),
  60. comp_node=comp_node,
  61. comp_graph=comp_graph,
  62. )
  63. arange = (
  64. arange.reshape(*shape)
  65. .broadcast(index.shape)
  66. .reshape(-1)
  67. .astype(np.int32)
  68. )
  69. idx.append(arange)
  70. else:
  71. idx.append(index.reshape(-1))
  72. return tuple(idx)
  73. @wrap_io_tensor
  74. def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
  75. r"""
  76. Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`.
  77. For a 3-D tensor, the output is specified by::
  78. out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
  79. out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
  80. out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2
  81. if :attr:`inp` is an n-dimensional tensor with size
  82. :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
  83. then :attr:`index` must be an n-dimensional tensor with size
  84. :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
  85. output will have the same size as :attr:`index`.
  86. :param inp: the source tensor
  87. :param axis: the axis along which to index
  88. :param index: the indices of elements to gather
  89. Examples:
  90. .. testcode::
  91. import megengine.functional as F
  92. from megengine.core import tensor
  93. inp = tensor([
  94. [1,2], [3,4], [5,6],
  95. ])
  96. index = tensor([[0,2], [1,0]])
  97. oup = F.gather(inp, 0, index)
  98. print(oup.numpy())
  99. Outputs:
  100. .. testoutput::
  101. [[1 6]
  102. [3 2]]
  103. """
  104. input_shape = inp.imm_shape
  105. index_shape = index.imm_shape
  106. input_dims = len(input_shape)
  107. index_dims = len(index_shape)
  108. if input_dims != index_dims:
  109. raise ValueError(
  110. "The index tensor must have same dimensions as input tensor, "
  111. "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
  112. )
  113. if axis < 0 or axis >= input_dims:
  114. raise ValueError(
  115. "Index axis {} is output of bounds, should in range [0 {})".format(
  116. axis, input_dims
  117. )
  118. )
  119. for i in range(input_dims):
  120. if i != axis and input_shape[i] != index_shape[i]:
  121. raise ValueError(
  122. "The input {} and index {} must have the same size apart from axis {}".format(
  123. input_shape, index_shape, axis
  124. )
  125. )
  126. idx = _get_idx(index, axis)
  127. return mgb.opr.advanced_indexing(inp)[idx].reshape(
  128. index.shape
  129. ) # pylint: disable=no-member
  130. @wrap_io_tensor
  131. def concat(
  132. inps: Iterable[Tensor],
  133. axis: int = 0,
  134. device: Optional[CompNode] = None,
  135. comp_graph: Optional[CompGraph] = None,
  136. ) -> Tensor:
  137. r"""
  138. Concat some tensors
  139. :param inps: Input tensors to concat
  140. :param axis: the dimension over which the tensors are concatenated. Default: 0
  141. :param device: The comp node output on. Default: None
  142. :param comp_graph: The graph in which output is. Default: None
  143. :return: The output tensor
  144. Examples:
  145. .. testcode::
  146. import numpy as np
  147. from megengine import tensor
  148. import megengine.functional as F
  149. data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
  150. data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
  151. out = F.concat([data1, data2])
  152. print(out.numpy())
  153. Outputs:
  154. .. testoutput::
  155. [[ 0. 1. 2.]
  156. [ 3. 4. 5.]
  157. [ 6. 7. 8.]
  158. [ 9. 10. 11.]]
  159. """
  160. # Output buffer not supported
  161. return mgb.opr.concat(
  162. *list(inps), axis=axis, comp_node=device, comp_graph=comp_graph
  163. )
  164. @wrap_io_tensor
  165. def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
  166. r"""
  167. Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor.
  168. For each value in :attr:`source`, its output index is specified by its index
  169. in :attr:`source` for ``axis != dimension`` and by the corresponding value in
  170. :attr:`index` for ``axis = dimension``.
  171. For a 3-D tensor, :attr:`inp` is updated as::
  172. inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0
  173. inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1
  174. inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2
  175. :attr:`inp`, :attr:`index` and :attr:`source` should have same number of dimensions.
  176. It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
  177. for all dimensions ``d``.
  178. Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.
  179. .. note::
  180. Please notice that, due to performance issues, the result is uncertain on the GPU device
  181. if scatter difference positions from source to the same destination position
  182. regard to index tensor.
  183. Show the case using the following examples, the oup[0][2] is maybe
  184. from source[0][2] which value is 0.2256 or source[1][2] which value is 0.5339
  185. if set the index[1][2] from 1 to 0.
  186. :param inp: the inp tensor which to be scattered
  187. :param axis: the axis along which to index
  188. :param index: the indices of elements to scatter
  189. :param source: the source element(s) to scatter
  190. Examples:
  191. .. testcode::
  192. import numpy as np
  193. import megengine.functional as F
  194. from megengine.core import tensor
  195. inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
  196. source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
  197. index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
  198. oup = F.scatter(inp, 0, index,source)
  199. print(oup.numpy())
  200. Outputs:
  201. .. testoutput::
  202. [[0.9935 0.0718 0.2256 0. 0. ]
  203. [0. 0. 0.5939 0.357 0.4396]
  204. [0.7723 0.9465 0. 0.8926 0.4576]]
  205. """
  206. input_shape = inp.imm_shape
  207. index_shape = index.imm_shape
  208. source_shape = source.imm_shape
  209. input_dims = len(input_shape)
  210. index_dims = len(index_shape)
  211. source_dims = len(source_shape)
  212. if input_dims != index_dims or input_dims != source_dims:
  213. raise ValueError("The input, source and index tensor must have same dimensions")
  214. if axis < 0 or axis >= input_dims:
  215. raise ValueError(
  216. "Index axis {} is output of bounds, should in range [0 {})".format(
  217. axis, input_dims
  218. )
  219. )
  220. for i in range(source_dims):
  221. if source_shape[i] > input_shape[i]:
  222. raise ValueError(
  223. "The each shape size for source {} must be less than or equal to input {} ".format(
  224. source_shape, input_shape
  225. )
  226. )
  227. for i in range(index_dims):
  228. if index_shape[i] != source_shape[i]:
  229. raise ValueError(
  230. "The each shape size for index {} must be equal to source {} ".format(
  231. index_shape, source_shape
  232. )
  233. )
  234. for i in range(index_dims):
  235. if i != axis and index_shape[i] > input_shape[i]:
  236. raise ValueError(
  237. "The index {} must be less than or equal to input {} size apart from axis {}".format(
  238. index_shape, input_shape, axis
  239. )
  240. )
  241. idx = _get_idx(index, axis)
  242. return mgb.opr.set_advanced_indexing(inp, source.flatten())[idx]
  243. @wrap_io_tensor
  244. def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
  245. r"""
  246. Select elements either from Tensor x or Tensor y, according to mask.
  247. .. math::
  248. \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i
  249. :param mask: a mask used for choosing x or y
  250. :param x: the first choice
  251. :param y: the second choice
  252. Examples:
  253. .. testcode::
  254. from megengine import tensor
  255. import megengine.functional as F
  256. mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
  257. x = tensor(np.array([[1, np.inf], [np.nan, 4]],
  258. dtype=np.float32))
  259. y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
  260. out = F.where(mask, x, y)
  261. print(out.numpy())
  262. Outputs:
  263. .. testoutput::
  264. [[1. 6.]
  265. [7. 4.]]
  266. """
  267. v0, index0 = mgb.opr.cond_take(
  268. x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=1
  269. )
  270. v1, index1 = mgb.opr.cond_take(
  271. y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0
  272. )
  273. out = x.flatten()
  274. index = mgb.opr.concat(index0, index1, axis=0)
  275. v = mgb.opr.concat(v0, v1, axis=0)
  276. out = mgb.opr.set_advanced_indexing(out, v)[index]
  277. out = out.reshape(x.shape)
  278. return out
  279. @wrap_io_tensor
  280. def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor:
  281. r"""
  282. Take elements from data if specific condition is satisfied on mask. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; they are both 1-dimensional. High-dimension input would first be flattened.
  283. :param mask: condition param; must be the same shape with data
  284. :param x: input tensor from which to take elements
  285. :param val: value to be compared to by mode
  286. Examples:
  287. .. testcode::
  288. from megengine import tensor
  289. import megengine.functional as F
  290. mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
  291. x = tensor(np.array([[1, np.inf], [np.nan, 4]],
  292. dtype=np.float32))
  293. v, index = F.cond_take(mask, x, 1)
  294. print(v, index)
  295. Outputs:
  296. .. testoutput::
  297. Tensor([1. 4.]) Tensor([0 3], dtype=int32)
  298. """
  299. v, index = mgb.opr.cond_take(
  300. x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val
  301. )
  302. return v, index
  303. def shapeof(x: Tensor, axis=None):
  304. r"""
  305. The shape of input tensor.
  306. """
  307. return x.shapeof(axis=axis)
  308. @wrap_io_tensor
  309. def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
  310. r"""
  311. Swap shapes and strides according to given pattern
  312. :param inp: Input tensor
  313. :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:
  314. * (``'x'``) -> make a 0d (scalar) into a 1d vector
  315. * (0, 1) -> identity for 2d vectors
  316. * (1, 0) -> inverts the first and second dimensions
  317. * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
  318. * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
  319. * (2, 0, 1) -> AxBxC to CxAxB
  320. * (0, ``'x'``, 1) -> AxB to Ax1xB
  321. * (1, ``'x'``, 0) -> AxB to Bx1xA
  322. * (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A)
  323. :return: The output tensor
  324. Examples:
  325. .. testcode::
  326. import numpy as np
  327. from megengine import tensor
  328. import megengine.functional as F
  329. x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
  330. out = F.dimshuffle(x, (1, 0))
  331. print(out.numpy())
  332. Outputs:
  333. .. testoutput::
  334. [[1 0]
  335. [1 0]]
  336. """
  337. return mgb.opr.dimshuffle(inp, pattern)
  338. @wrap_io_tensor
  339. def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
  340. r"""
  341. Reshape a tensor to given target shape; total number of logical elements must
  342. remain unchanged
  343. :param inp: Input tensor
  344. :param target_shape: target shape, the components would be concatenated to form the
  345. target shape, and it can contain an element of -1 representing unspec_axis.
  346. Examples:
  347. .. testcode::
  348. import numpy as np
  349. from megengine import tensor
  350. import megengine.functional as F
  351. x = tensor(np.arange(12, dtype=np.int32))
  352. out = F.reshape(x, (3, 2, 2))
  353. print(out.numpy())
  354. Outputs:
  355. .. testoutput::
  356. [[[ 0 1]
  357. [ 2 3]]
  358. [[ 4 5]
  359. [ 6 7]]
  360. [[ 8 9]
  361. [10 11]]]
  362. """
  363. return mgb.opr.reshape(inp, target_shape)
  364. def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
  365. r"""Equivalent to :func:`dimshuffle`
  366. """
  367. return dimshuffle(inp, pattern)
  368. @wrap_io_tensor
  369. def add_axis(inp: Tensor, axis: int) -> Tensor:
  370. r"""
  371. Add dimension before given axis.
  372. :param inp: Input tensor
  373. :param axis: Place of new axes
  374. :return: The output tensor
  375. Examples:
  376. .. testcode::
  377. import numpy as np
  378. from megengine import tensor
  379. import megengine.functional as F
  380. x = tensor([1, 2])
  381. out = F.add_axis(x, 0)
  382. print(out.shape)
  383. Outputs:
  384. .. testoutput::
  385. (1, 2)
  386. """
  387. if not isinstance(axis, int):
  388. raise ValueError("axis must be int, but got type:{}".format(type(axis)))
  389. return mgb.opr.add_axis(inp, axis)
  390. @wrap_io_tensor
  391. def remove_axis(inp: Tensor, axis: int) -> Tensor:
  392. r"""
  393. Remove dimension of shape 1.
  394. :param inp: Input tensor
  395. :param axis: Place of axis to be removed
  396. :return: The output tensor
  397. Examples:
  398. .. testcode::
  399. import numpy as np
  400. from megengine import tensor
  401. import megengine.functional as F
  402. x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
  403. out = F.remove_axis(x, 3)
  404. print(out.shape)
  405. Outputs:
  406. .. testoutput::
  407. (1, 1, 2)
  408. """
  409. if not isinstance(axis, int):
  410. raise ValueError("axis must be int, but got type:{}".format(type(axis)))
  411. return mgb.opr.remove_axis(inp, axis)
  412. def linspace(
  413. start: Union[int, float, Tensor],
  414. stop: Union[int, float, Tensor],
  415. num: Union[int, Tensor],
  416. dtype=np.float32,
  417. device: Optional[CompNode] = None,
  418. comp_graph: Optional[CompGraph] = None,
  419. ) -> Tensor:
  420. r"""
  421. Return equally spaced numbers over a specified interval
  422. :param start: Starting value of the squence, shoule be scalar
  423. :param stop: The last value of the squence, shoule be scalar
  424. :param num: number of values to generate
  425. :param dtype: result data type
  426. :return: The generated tensor
  427. Examples:
  428. .. testcode::
  429. import numpy as np
  430. import megengine.functional as F
  431. a = F.linspace(3,10,5)
  432. print(a.numpy())
  433. .. testoutput::
  434. [ 3. 4.75 6.5 8.25 10. ]
  435. """
  436. if dtype is not np.float32:
  437. raise ValueError("linspace is only implemented for float32")
  438. device, comp_graph = _use_default_if_none(device, comp_graph)
  439. ret = Tensor(
  440. mgb.opr.linspace(start, stop, num, comp_node=device, comp_graph=comp_graph)
  441. )
  442. return ret.astype(dtype)
  443. def arange(
  444. start: Union[int, float, Tensor],
  445. end: Union[int, float, Tensor],
  446. step: Union[int, float, Tensor] = 1,
  447. dtype=np.float32,
  448. device: Optional[CompNode] = None,
  449. comp_graph: Optional[CompGraph] = None,
  450. ) -> Tensor:
  451. r"""
  452. Returns a Tensor with values from `start` to `end` with adjacent interval `step`
  453. :param start: starting value of the squence, shoule be scalar
  454. :param end: ending value of the squence, shoule be scalar
  455. :param step: the gap between each pair of adjacent values. Default 1
  456. :param dtype: result data type
  457. :return: The generated tensor
  458. Examples:
  459. .. testcode::
  460. import numpy as np
  461. import megengine.functional as F
  462. a = F.arange(1, 5, 1)
  463. print(a.numpy())
  464. .. testoutput::
  465. [1. 2. 3. 4.]
  466. """
  467. if dtype is not np.float32:
  468. raise ValueError("arange is only implemented for float32")
  469. num = ceil((end - start) / step)
  470. stop = start + step * (num - 1)
  471. ret = linspace(start, stop, num, device=device, comp_graph=comp_graph)
  472. return ret
  473. def zeros_like(inp: Tensor) -> Tensor:
  474. r"""
  475. Returns a zero tensor with the same shape as input tensor
  476. :param inp: input tensor
  477. Examples:
  478. .. testcode::
  479. import numpy as np
  480. from megengine import tensor
  481. import megengine.functional as F
  482. inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  483. out = F.zeros_like(inp)
  484. print(out.numpy())
  485. .. testoutput::
  486. [[0 0 0]
  487. [0 0 0]]
  488. """
  489. return zeros(inp.shapeof()).astype(inp.dtype)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。