You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

tensor.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. from typing import Iterable, List, Optional, Union
  11. import numpy as np
  12. import megengine._internal as mgb
  13. from megengine._internal import CompGraph, CompNode
  14. from ..core import zeros
  15. from ..core.graph import _use_default_if_none
  16. from ..core.tensor import Tensor, wrap_io_tensor
  17. from .elemwise import ceil
  18. from .utils import _decide_comp_node_and_comp_graph
  19. @wrap_io_tensor
  20. def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
  21. """
  22. Broadcast a tensor to ``shape``
  23. :param inp: The input tensor
  24. :param shape: The target shape
  25. :return: The output tensor
  26. Examples:
  27. .. testcode::
  28. import numpy as np
  29. from megengine import tensor
  30. import megengine.functional as F
  31. data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  32. out = F.broadcast_to(data, (4, 2, 3))
  33. print(out.numpy())
  34. Outputs:
  35. .. testoutput::
  36. [[[0. 1. 2.]
  37. [3. 4. 5.]]
  38. [[0. 1. 2.]
  39. [3. 4. 5.]]
  40. [[0. 1. 2.]
  41. [3. 4. 5.]]
  42. [[0. 1. 2.]
  43. [3. 4. 5.]]]
  44. """
  45. if isinstance(shape, int):
  46. shape = (shape,)
  47. return mgb.opr.broadcast(inp, shape)
  48. def _get_idx(index, axis):
  49. index_dims = len(index.imm_shape)
  50. idx = []
  51. comp_node, comp_graph = _decide_comp_node_and_comp_graph(index)
  52. for i in range(index_dims):
  53. if i != axis:
  54. shape = [1] * index_dims
  55. shape[i] = index.axis_shape(i)
  56. arange = mgb.opr.linspace(
  57. 0,
  58. index.axis_shape(i) - 1,
  59. index.axis_shape(i),
  60. comp_node=comp_node,
  61. comp_graph=comp_graph,
  62. )
  63. arange = (
  64. arange.reshape(*shape)
  65. .broadcast(index.shape)
  66. .reshape(-1)
  67. .astype(np.int32)
  68. )
  69. idx.append(arange)
  70. else:
  71. idx.append(index.reshape(-1))
  72. return tuple(idx)
  73. @wrap_io_tensor
  74. def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
  75. r"""
  76. Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`.
  77. For a 3-D tensor, the output is specified by::
  78. out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
  79. out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
  80. out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2
  81. if :attr:`inp` is an n-dimensional tensor with size
  82. :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
  83. then :attr:`index` must be an n-dimensional tensor with size
  84. :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
  85. output will have the same size as :attr:`index`.
  86. :param inp: the source tensor
  87. :param axis: the axis along which to index
  88. :param index: the indices of elements to gather
  89. Examples:
  90. .. testcode::
  91. import megengine.functional as F
  92. from megengine.core import tensor
  93. inp = tensor([
  94. [1,2], [3,4], [5,6],
  95. ])
  96. index = tensor([[0,2], [1,0]])
  97. oup = F.gather(inp, 0, index)
  98. print(oup.numpy())
  99. Outputs:
  100. .. testoutput::
  101. [[1 6]
  102. [3 2]]
  103. """
  104. input_shape = inp.imm_shape
  105. index_shape = index.imm_shape
  106. input_dims = len(input_shape)
  107. index_dims = len(index_shape)
  108. if input_dims != index_dims:
  109. raise ValueError(
  110. "The index tensor must have same dimensions as input tensor, "
  111. "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
  112. )
  113. if axis < 0 or axis >= input_dims:
  114. raise ValueError(
  115. "Index axis {} is output of bounds, should in range [0 {})".format(
  116. axis, input_dims
  117. )
  118. )
  119. for i in range(input_dims):
  120. if i != axis and input_shape[i] != index_shape[i]:
  121. raise ValueError(
  122. "The input {} and index {} must have the same size apart from axis {}".format(
  123. input_shape, index_shape, axis
  124. )
  125. )
  126. idx = _get_idx(index, axis)
  127. return mgb.opr.advanced_indexing(inp)[idx].reshape(
  128. index.shape
  129. ) # pylint: disable=no-member
  130. @wrap_io_tensor
  131. def concat(
  132. inps: Iterable[Tensor],
  133. axis: int = 0,
  134. device: Optional[CompNode] = None,
  135. comp_graph: Optional[CompGraph] = None,
  136. ) -> Tensor:
  137. r"""
  138. Concat some tensors
  139. :param inps: Input tensors to concat
  140. :param axis: the dimension over which the tensors are concatenated,
  141. default to 0
  142. :param device: The comp node output on, default to None
  143. :param comp_graph: The graph in which output is, default to None
  144. :return: The output tensor
  145. Examples:
  146. .. testcode::
  147. import numpy as np
  148. from megengine import tensor
  149. import megengine.functional as F
  150. data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
  151. data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
  152. out = F.concat([data1, data2])
  153. print(out.numpy())
  154. Outputs:
  155. .. testoutput::
  156. [[ 0. 1. 2.]
  157. [ 3. 4. 5.]
  158. [ 6. 7. 8.]
  159. [ 9. 10. 11.]]
  160. """
  161. # Output buffer not supported
  162. return mgb.opr.concat(
  163. *list(inps), axis=axis, comp_node=device, comp_graph=comp_graph
  164. )
  165. @wrap_io_tensor
  166. def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
  167. r"""
  168. Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor.
  169. For each value in :attr:`source`, its output index is specified by its index
  170. in :attr:`source` for ``axis != dimension`` and by the corresponding value in
  171. :attr:`index` for ``axis = dimension``.
  172. For a 3-D tensor, :attr:`inp` is updated as::
  173. inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0
  174. inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1
  175. inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2
  176. :attr:`inp`, :attr:`index` and :attr:`source` should have same number of dimensions.
  177. It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
  178. for all dimensions ``d``.
  179. Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.
  180. :param inp: the inp tensor which to be scattered
  181. :param axis: the axis along which to index
  182. :param index: the indices of elements to scatter
  183. :param source: the source element(s) to scatter
  184. Examples:
  185. .. testcode::
  186. import numpy as np
  187. import megengine.functional as F
  188. from megengine.core import tensor
  189. inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
  190. source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
  191. index = tensor([[0,2,0,2,1],[2,0,0,1,2]])
  192. oup = F.scatter(inp, 0, index,source)
  193. print(oup.numpy())
  194. Outputs:
  195. .. testoutput::
  196. :options: +SKIP
  197. [[0.9935 0.0718 0.5939 0. 0. ]
  198. [0. 0. 0. 0.357 0.4396]
  199. [0.7723 0.9465 0. 0.8926 0.4576]]
  200. """
  201. input_shape = inp.imm_shape
  202. index_shape = index.imm_shape
  203. source_shape = source.imm_shape
  204. input_dims = len(input_shape)
  205. index_dims = len(index_shape)
  206. source_dims = len(source_shape)
  207. if input_dims != index_dims or input_dims != source_dims:
  208. raise ValueError("The input, source and index tensor must have same dimensions")
  209. if axis < 0 or axis >= input_dims:
  210. raise ValueError(
  211. "Index axis {} is output of bounds, should in range [0 {})".format(
  212. axis, input_dims
  213. )
  214. )
  215. for i in range(source_dims):
  216. if source_shape[i] > input_shape[i]:
  217. raise ValueError(
  218. "The each shape size for source {} must be less than or equal to input {} ".format(
  219. source_shape, input_shape
  220. )
  221. )
  222. for i in range(index_dims):
  223. if index_shape[i] != source_shape[i]:
  224. raise ValueError(
  225. "The each shape size for index {} must be equal to source {} ".format(
  226. index_shape, source_shape
  227. )
  228. )
  229. for i in range(index_dims):
  230. if i != axis and index_shape[i] > input_shape[i]:
  231. raise ValueError(
  232. "The index {} must be less than or equal to input {} size apart from axis {}".format(
  233. index_shape, input_shape, axis
  234. )
  235. )
  236. idx = _get_idx(index, axis)
  237. return mgb.opr.set_advanced_indexing(inp, source.flatten())[idx]
  238. @wrap_io_tensor
  239. def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
  240. r"""
  241. Select elements either from Tensor x or Tensor y, according to mask.
  242. .. math::
  243. \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i
  244. :param mask: a mask used for choosing x or y
  245. :param x: the first choice
  246. :param y: the second choice
  247. Examples:
  248. .. testcode::
  249. from megengine import tensor
  250. import megengine.functional as F
  251. mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
  252. x = tensor(np.array([[1, np.inf], [np.nan, 4]],
  253. dtype=np.float32))
  254. y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
  255. out = F.where(mask, x, y)
  256. print(out.numpy())
  257. Outputs:
  258. .. testoutput::
  259. [[1. 6.]
  260. [7. 4.]]
  261. """
  262. v0, index0 = mgb.opr.cond_take(
  263. x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=1
  264. )
  265. v1, index1 = mgb.opr.cond_take(
  266. y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0
  267. )
  268. out = x.flatten()
  269. out = mgb.opr.set_advanced_indexing(out, v0)[index0]
  270. out = mgb.opr.set_advanced_indexing(out, v1)[index1]
  271. out = out.reshape(x.shape)
  272. return out
  273. def shapeof(x: Tensor, axis=None):
  274. r"""
  275. The shape of input tensor.
  276. """
  277. return x.shapeof(axis=axis)
  278. @wrap_io_tensor
  279. def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
  280. r"""
  281. Swap shapes and strides according to given pattern
  282. :param inp: Input tensor
  283. :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:
  284. * (``'x'``) -> make a 0d (scalar) into a 1d vector
  285. * (0, 1) -> identity for 2d vectors
  286. * (1, 0) -> inverts the first and second dimensions
  287. * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
  288. * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
  289. * (2, 0, 1) -> AxBxC to CxAxB
  290. * (0, ``'x'``, 1) -> AxB to Ax1xB
  291. * (1, ``'x'``, 0) -> AxB to Bx1xA
  292. * (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A)
  293. :return: The output tensor
  294. Examples:
  295. .. testcode::
  296. import numpy as np
  297. from megengine import tensor
  298. import megengine.functional as F
  299. x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
  300. out = F.dimshuffle(x, (1, 0))
  301. print(out.numpy())
  302. Outputs:
  303. .. testoutput::
  304. [[1 0]
  305. [1 0]]
  306. """
  307. return mgb.opr.dimshuffle(inp, pattern)
  308. @wrap_io_tensor
  309. def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
  310. r"""
  311. Reshape a tensor to given target shape; total number of logical elements must
  312. remain unchanged
  313. :param inp: Input tensor
  314. :param target_shape: target shape, the components would be concatenated to form the
  315. target shape, and it can contain an element of -1 representing unspec_axis.
  316. Examples:
  317. .. testcode::
  318. import numpy as np
  319. from megengine import tensor
  320. import megengine.functional as F
  321. x = tensor(np.arange(12, dtype=np.int32))
  322. out = F.reshape(x, (3, 2, 2))
  323. print(out.numpy())
  324. Outputs:
  325. .. testoutput::
  326. [[[ 0 1]
  327. [ 2 3]]
  328. [[ 4 5]
  329. [ 6 7]]
  330. [[ 8 9]
  331. [10 11]]]
  332. """
  333. return mgb.opr.reshape(inp, target_shape)
  334. def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
  335. r"""Equivalent to :func:`dimshuffle`
  336. """
  337. return dimshuffle(inp, pattern)
  338. @wrap_io_tensor
  339. def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
  340. r"""
  341. Add dimension(s) before given axis/axes
  342. :param inp: Input tensor
  343. :param axis: Place(s) of new axes
  344. :return: The output tensor
  345. Examples:
  346. .. testcode::
  347. import numpy as np
  348. from megengine import tensor
  349. import megengine.functional as F
  350. x = tensor([1, 2])
  351. out = F.add_axis(x, (0, 2))
  352. print(out.shape)
  353. Outputs:
  354. .. testoutput::
  355. (1, 2, 1)
  356. """
  357. return mgb.opr.add_axis(inp, axis)
  358. @wrap_io_tensor
  359. def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
  360. r"""
  361. Remove dimension(s) of shape 1
  362. :param inp: Input tensor
  363. :param axis: Place(s) of axes to be removed
  364. :return: The output tensor
  365. Examples:
  366. .. testcode::
  367. import numpy as np
  368. from megengine import tensor
  369. import megengine.functional as F
  370. x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
  371. out = F.remove_axis(x, (0, 0, 1))
  372. print(out.shape)
  373. Outputs:
  374. .. testoutput::
  375. (2,)
  376. """
  377. return mgb.opr.remove_axis(inp, axis)
  378. def linspace(
  379. start: Union[int, float, Tensor],
  380. stop: Union[int, float, Tensor],
  381. num: int = 100,
  382. dtype=np.float32,
  383. device: Optional[CompNode] = None,
  384. comp_graph: Optional[CompGraph] = None,
  385. ) -> Tensor:
  386. r"""
  387. Return equally spaced numbers over a specified interval
  388. :param start: Starting value of the squence, shoule be scalar
  389. :param stop: The last value of the squence, shoule be scalar
  390. :param num: number of values to generate
  391. :param dtype: result data type
  392. :return: The generated tensor
  393. Examples:
  394. .. testcode::
  395. import numpy as np
  396. import megengine.functional as F
  397. a = F.linspace(3,10,5)
  398. print(a.numpy())
  399. .. testoutput::
  400. [ 3. 4.75 6.5 8.25 10. ]
  401. """
  402. if dtype is not np.float32:
  403. raise ValueError("linspace is only implemented for float32")
  404. device, comp_graph = _use_default_if_none(device, comp_graph)
  405. ret = Tensor(
  406. mgb.opr.linspace(start, stop, num, comp_node=device, comp_graph=comp_graph)
  407. )
  408. return ret.astype(dtype)
  409. def arange(
  410. start: Union[int, float, Tensor],
  411. end: Union[int, float, Tensor],
  412. step: Union[int, float, Tensor] = 1,
  413. dtype=np.float32,
  414. device: Optional[CompNode] = None,
  415. comp_graph: Optional[CompGraph] = None,
  416. ) -> Tensor:
  417. r"""
  418. Returns a Tensor with values from `start` to `end` with adjacent interval `step`
  419. :param start: starting value of the squence, shoule be scalar
  420. :param end: ending value of the squence, shoule be scalar
  421. :param step: the gap between each pair of adjacent values. Default 1
  422. :param dtype: result data type
  423. :return: The generated tensor
  424. Examples:
  425. .. testcode::
  426. import numpy as np
  427. import megengine.functional as F
  428. a = F.arange(1, 5, 1)
  429. print(a.numpy())
  430. .. testoutput::
  431. [1. 2. 3. 4.]
  432. """
  433. if dtype is not np.float32:
  434. raise ValueError("arange is only implemented for float32")
  435. num = ceil((end - start) / step)
  436. stop = start + step * (num - 1)
  437. ret = linspace(start, stop, num, device=device, comp_graph=comp_graph)
  438. return ret
  439. def zeros_like(inp: Tensor) -> Tensor:
  440. r"""
  441. Returns a zero tensor with the same shape as input tensor
  442. :param inp: input tensor
  443. Examples:
  444. .. testcode::
  445. import numpy as np
  446. from megengine import tensor
  447. import megengine.functional as F
  448. inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  449. out = F.zeros_like(inp)
  450. print(out.numpy())
  451. .. testoutput::
  452. [[0 0 0]
  453. [0 0 0]]
  454. """
  455. return zeros(inp.shapeof()).astype(inp.dtype)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。