
tensor.py 18 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
from typing import Iterable, List, Optional, Union

import numpy as np

import megengine._internal as mgb
from megengine._internal import CompGraph, CompNode

from ..core import zeros
from ..core.graph import _use_default_if_none
from ..core.tensor import Tensor, wrap_io_tensor
from .elemwise import ceil
from .utils import _decide_comp_node_and_comp_graph


@wrap_io_tensor
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
    """
    Broadcast a tensor to ``shape``

    :param inp: The input tensor
    :param shape: The target shape
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.broadcast_to(data, (4, 2, 3))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[[0. 1. 2.]
          [3. 4. 5.]]

         [[0. 1. 2.]
          [3. 4. 5.]]

         [[0. 1. 2.]
          [3. 4. 5.]]

         [[0. 1. 2.]
          [3. 4. 5.]]]

    """
    if isinstance(shape, int):
        shape = (shape,)
    return mgb.opr.broadcast(inp, shape)


def _get_idx(index, axis):
    index_dims = len(index.imm_shape)
    idx = []
    comp_node, comp_graph = _decide_comp_node_and_comp_graph(index)
    for i in range(index_dims):
        if i != axis:
            shape = [1] * index_dims
            shape[i] = index.axis_shape(i)
            arange = mgb.opr.linspace(
                0,
                index.axis_shape(i) - 1,
                index.axis_shape(i),
                comp_node=comp_node,
                comp_graph=comp_graph,
            )
            arange = (
                arange.reshape(*shape)
                .broadcast(index.shape)
                .reshape(-1)
                .astype(np.int32)
            )
            idx.append(arange)
        else:
            idx.append(index.reshape(-1))
    return tuple(idx)


@wrap_io_tensor
def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
    r"""
    Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`.

    For a 3-D tensor, the output is specified by::

        out[i][j][k] = inp[index[i][j][k]][j][k]  # if axis == 0
        out[i][j][k] = inp[i][index[i][j][k]][k]  # if axis == 1
        out[i][j][k] = inp[i][j][index[i][j][k]]  # if axis == 2

    If :attr:`inp` is an n-dimensional tensor with size
    :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and ``axis=i``,
    then :attr:`index` must be an n-dimensional tensor with size
    :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1`, and
    the output will have the same size as :attr:`index`.

    :param inp: the source tensor
    :param axis: the axis along which to index
    :param index: the indices of elements to gather

    Examples:

    .. testcode::

        import megengine.functional as F
        from megengine.core import tensor

        inp = tensor([
            [1, 2], [3, 4], [5, 6],
        ])
        index = tensor([[0, 2], [1, 0]])
        oup = F.gather(inp, 0, index)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[1 6]
         [3 2]]

    """
    input_shape = inp.imm_shape
    index_shape = index.imm_shape
    input_dims = len(input_shape)
    index_dims = len(index_shape)
    if input_dims != index_dims:
        raise ValueError(
            "The index tensor must have the same number of dimensions as the input tensor, "
            "but input dims: {}, index dims: {}".format(input_dims, index_dims)
        )
    if axis < 0 or axis >= input_dims:
        raise ValueError(
            "Index axis {} is out of bounds, should be in range [0, {})".format(
                axis, input_dims
            )
        )
    for i in range(input_dims):
        if i != axis and input_shape[i] != index_shape[i]:
            raise ValueError(
                "The input {} and index {} must have the same size apart from axis {}".format(
                    input_shape, index_shape, axis
                )
            )
    idx = _get_idx(index, axis)
    return mgb.opr.advanced_indexing(inp)[idx].reshape(
        index.shape
    )  # pylint: disable=no-member


@wrap_io_tensor
def concat(
    inps: Iterable[Tensor],
    axis: int = 0,
    device: Optional[CompNode] = None,
    comp_graph: Optional[CompGraph] = None,
) -> Tensor:
    r"""
    Concatenate the input tensors along the given axis

    :param inps: Input tensors to concat
    :param axis: the dimension over which the tensors are concatenated. Default: 0
    :param device: the comp node on which the output resides. Default: None
    :param comp_graph: the graph in which the output resides. Default: None
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
        data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
        out = F.concat([data1, data2])
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[ 0.  1.  2.]
         [ 3.  4.  5.]
         [ 6.  7.  8.]
         [ 9. 10. 11.]]

    """
    # Output buffer not supported
    return mgb.opr.concat(
        *list(inps), axis=axis, comp_node=device, comp_graph=comp_graph
    )


@wrap_io_tensor
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
    r"""
    Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor.

    For each value in :attr:`source`, its output index is specified by its index
    in :attr:`source` for ``axis != dimension`` and by the corresponding value in
    :attr:`index` for ``axis = dimension``.

    For a 3-D tensor, :attr:`inp` is updated as::

        inp[index[i][j][k]][j][k] = source[i][j][k]  # if axis == 0
        inp[i][index[i][j][k]][k] = source[i][j][k]  # if axis == 1
        inp[i][j][index[i][j][k]] = source[i][j][k]  # if axis == 2

    :attr:`inp`, :attr:`index` and :attr:`source` should have the same number of dimensions.

    It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
    for all dimensions ``d``.

    Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.

    :param inp: the input tensor to be scattered
    :param axis: the axis along which to index
    :param index: the indices of elements to scatter
    :param source: the source element(s) to scatter

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine.core import tensor

        inp = tensor(np.zeros(shape=(3, 5), dtype=np.float32))
        source = tensor([[0.9935, 0.9465, 0.2256, 0.8926, 0.4396], [0.7723, 0.0718, 0.5939, 0.357, 0.4576]])
        index = tensor([[0, 2, 0, 2, 1], [2, 0, 0, 1, 2]])
        oup = F.scatter(inp, 0, index, source)
        print(oup.numpy())

    Outputs:

    .. testoutput::
        :options: +SKIP

        [[0.9935 0.0718 0.5939 0.     0.    ]
         [0.     0.     0.     0.357  0.4396]
         [0.7723 0.9465 0.     0.8926 0.4576]]

    """
    input_shape = inp.imm_shape
    index_shape = index.imm_shape
    source_shape = source.imm_shape
    input_dims = len(input_shape)
    index_dims = len(index_shape)
    source_dims = len(source_shape)

    if input_dims != index_dims or input_dims != source_dims:
        raise ValueError("The input, source and index tensors must have the same number of dimensions")

    if axis < 0 or axis >= input_dims:
        raise ValueError(
            "Index axis {} is out of bounds, should be in range [0, {})".format(
                axis, input_dims
            )
        )

    for i in range(source_dims):
        if source_shape[i] > input_shape[i]:
            raise ValueError(
                "Each shape size for source {} must be less than or equal to input {}".format(
                    source_shape, input_shape
                )
            )

    for i in range(index_dims):
        if index_shape[i] != source_shape[i]:
            raise ValueError(
                "Each shape size for index {} must be equal to source {}".format(
                    index_shape, source_shape
                )
            )

    for i in range(index_dims):
        if i != axis and index_shape[i] > input_shape[i]:
            raise ValueError(
                "The index {} must be less than or equal to input {} size apart from axis {}".format(
                    index_shape, input_shape, axis
                )
            )

    idx = _get_idx(index, axis)
    return mgb.opr.set_advanced_indexing(inp, source.flatten())[idx]


@wrap_io_tensor
def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
    r"""
    Select elements either from Tensor x or Tensor y, according to mask.

    .. math::

        \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i

    :param mask: a mask used for choosing x or y
    :param x: the first choice
    :param y: the second choice

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
        x = tensor(np.array([[1, np.inf], [np.nan, 4]],
                            dtype=np.float32))
        y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
        out = F.where(mask, x, y)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[1. 6.]
         [7. 4.]]

    """
    v0, index0 = mgb.opr.cond_take(
        x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=1
    )
    v1, index1 = mgb.opr.cond_take(
        y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0
    )
    out = x.flatten()
    index = mgb.opr.concat(index0, index1, axis=0)
    v = mgb.opr.concat(v0, v1, axis=0)
    out = mgb.opr.set_advanced_indexing(out, v)[index]
    out = out.reshape(x.shape)
    return out


@wrap_io_tensor
def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor:
    r"""
    Take elements from data if the specific condition is satisfied on mask. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; they are both 1-dimensional. High-dimensional input is first flattened.

    :param mask: condition param; must have the same shape as data
    :param x: input tensor from which to take elements
    :param val: value to be compared to by mode

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
        x = tensor(np.array([[1, np.inf], [np.nan, 4]],
                            dtype=np.float32))
        v, index = F.cond_take(mask, x, 1)
        print(v, index)

    Outputs:

    .. testoutput::

        Tensor([1. 4.]) Tensor([0 3], dtype=int32)

    """
    v, index = mgb.opr.cond_take(
        x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val
    )
    return v, index


def shapeof(x: Tensor, axis=None):
    r"""
    The shape of input tensor.
    """
    return x.shapeof(axis=axis)


@wrap_io_tensor
def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
    r"""
    Swap shapes and strides according to given pattern

    :param inp: Input tensor
    :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:

        * (``'x'``) -> make a 0d (scalar) into a 1d vector
        * (0, 1) -> identity for 2d vectors
        * (1, 0) -> swaps the first and second dimensions
        * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
        * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
        * (2, 0, 1) -> AxBxC to CxAxB
        * (0, ``'x'``, 1) -> AxB to Ax1xB
        * (1, ``'x'``, 0) -> AxB to Bx1xA
        * (1,) -> removes dimension 0; it must be a broadcastable dimension (1xA to A)

    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
        out = F.dimshuffle(x, (1, 0))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[1 0]
         [1 0]]

    """
    return mgb.opr.dimshuffle(inp, pattern)


@wrap_io_tensor
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
    r"""
    Reshape a tensor to given target shape; total number of logical elements must
    remain unchanged

    :param inp: Input tensor
    :param target_shape: target shape, the components would be concatenated to form the
        target shape, and it can contain an element of -1 representing unspec_axis.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(12, dtype=np.int32))
        out = F.reshape(x, (3, 2, 2))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[[ 0  1]
          [ 2  3]]

         [[ 4  5]
          [ 6  7]]

         [[ 8  9]
          [10 11]]]

    """
    return mgb.opr.reshape(inp, target_shape)


def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
    r"""Equivalent to :func:`dimshuffle`
    """
    return dimshuffle(inp, pattern)


@wrap_io_tensor
def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
    r"""
    Add dimension(s) before given axis/axes

    :param inp: Input tensor
    :param axis: Place(s) of new axes
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, 2])
        out = F.add_axis(x, (0, 2))
        print(out.shape)

    Outputs:

    .. testoutput::

        (1, 2, 1)

    """
    return mgb.opr.add_axis(inp, axis)


@wrap_io_tensor
def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
    r"""
    Remove dimension(s) of shape 1

    :param inp: Input tensor
    :param axis: Place(s) of axes to be removed
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
        out = F.remove_axis(x, (0, 0, 1))
        print(out.shape)

    Outputs:

    .. testoutput::

        (2,)

    """
    return mgb.opr.remove_axis(inp, axis)


def linspace(
    start: Union[int, float, Tensor],
    stop: Union[int, float, Tensor],
    num: Union[int, Tensor],
    dtype=np.float32,
    device: Optional[CompNode] = None,
    comp_graph: Optional[CompGraph] = None,
) -> Tensor:
    r"""
    Return equally spaced numbers over a specified interval

    :param start: starting value of the sequence, should be a scalar
    :param stop: the last value of the sequence, should be a scalar
    :param num: number of values to generate
    :param dtype: result data type
    :return: The generated tensor

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F

        a = F.linspace(3, 10, 5)
        print(a.numpy())

    .. testoutput::

        [ 3.    4.75  6.5   8.25 10.  ]

    """
    if dtype is not np.float32:
        raise ValueError("linspace is only implemented for float32")
    device, comp_graph = _use_default_if_none(device, comp_graph)
    ret = Tensor(
        mgb.opr.linspace(start, stop, num, comp_node=device, comp_graph=comp_graph)
    )
    return ret.astype(dtype)


def arange(
    start: Union[int, float, Tensor],
    end: Union[int, float, Tensor],
    step: Union[int, float, Tensor] = 1,
    dtype=np.float32,
    device: Optional[CompNode] = None,
    comp_graph: Optional[CompGraph] = None,
) -> Tensor:
    r"""
    Returns a Tensor with values from `start` to `end` with adjacent interval `step`

    :param start: starting value of the sequence, should be a scalar
    :param end: ending value of the sequence, should be a scalar
    :param step: the gap between each pair of adjacent values. Default 1
    :param dtype: result data type
    :return: The generated tensor

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F

        a = F.arange(1, 5, 1)
        print(a.numpy())

    .. testoutput::

        [1. 2. 3. 4.]

    """
    if dtype is not np.float32:
        raise ValueError("arange is only implemented for float32")
    num = ceil((end - start) / step)
    stop = start + step * (num - 1)
    ret = linspace(start, stop, num, device=device, comp_graph=comp_graph)
    return ret


def zeros_like(inp: Tensor) -> Tensor:
    r"""
    Returns a zero tensor with the same shape as input tensor

    :param inp: input tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.zeros_like(inp)
        print(out.numpy())

    .. testoutput::

        [[0 0 0]
         [0 0 0]]

    """
    return zeros(inp.shapeof()).astype(inp.dtype)

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and that its driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
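As a quick sanity check after installation, a minimal sketch like the following can report whether the GPU path is usable. It assumes the ``megengine.is_cuda_available()`` helper exposed by recent MegEngine releases; older ``megengine._internal``-based versions may expose this differently.

    # Hedged sanity-check sketch: assumes megengine.is_cuda_available() exists
    # (present in recent MegEngine releases; older versions may differ).
    import megengine as mge

    if mge.is_cuda_available():
        print("CUDA device detected; GPU execution is available.")
    else:
        print("No CUDA device found; running on CPU.")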