
tensor.py 16 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
from typing import Iterable, List, Optional, Union

import numpy as np

import megengine._internal as mgb
from megengine._internal import CompGraph, CompNode

from ..core import zeros
from ..core.graph import _use_default_if_none
from ..core.tensor import Tensor, wrap_io_tensor
from .utils import _decide_comp_node_and_comp_graph

@wrap_io_tensor
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
    """
    Broadcast a tensor to ``shape``

    :param inp: The input tensor
    :param shape: The target shape
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.broadcast_to(data, (4, 2, 3))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[[0. 1. 2.]
          [3. 4. 5.]]

         [[0. 1. 2.]
          [3. 4. 5.]]

         [[0. 1. 2.]
          [3. 4. 5.]]

         [[0. 1. 2.]
          [3. 4. 5.]]]

    """
    if isinstance(shape, int):
        shape = (shape,)
    return mgb.opr.broadcast(inp, shape)
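
# Helper used by gather() and scatter() below. Given an ``index`` tensor and an
# ``axis``, it builds one flat int32 coordinate vector per dimension of ``index``:
# the entry for ``axis`` is ``index`` itself flattened, while every other
# dimension gets its own coordinates (an arange reshaped and broadcast to
# ``index.shape``, then flattened). Feeding the resulting tuple to advanced
# indexing emulates a gather/scatter along ``axis``; e.g. for ``index`` of shape
# (2, 2) and ``axis=0`` the tuple is (index.reshape(-1), [0, 1, 0, 1]).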
def _get_idx(index, axis):
    index_dims = len(index.imm_shape)
    idx = []
    comp_node, comp_graph = _decide_comp_node_and_comp_graph(index)
    for i in range(index_dims):
        if i != axis:
            shape = [1] * index_dims
            shape[i] = index.axis_shape(i)
            arange = mgb.opr.linspace(
                0,
                index.axis_shape(i) - 1,
                index.axis_shape(i),
                comp_node=comp_node,
                comp_graph=comp_graph,
            )
            arange = (
                arange.reshape(*shape)
                .broadcast(index.shape)
                .reshape(-1)
                .astype(np.int32)
            )
            idx.append(arange)
        else:
            idx.append(index.reshape(-1))
    return tuple(idx)

@wrap_io_tensor
def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
    r"""
    Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`.

    For a 3-D tensor, the output is specified by::

        out[i][j][k] = inp[index[i][j][k]][j][k]  # if axis == 0
        out[i][j][k] = inp[i][index[i][j][k]][k]  # if axis == 1
        out[i][j][k] = inp[i][j][index[i][j][k]]  # if axis == 2

    If :attr:`inp` is an n-dimensional tensor of size
    :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and ``axis=i``,
    then :attr:`index` must be an n-dimensional tensor of size
    :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1`,
    and the output will have the same size as :attr:`index`.

    :param inp: the source tensor
    :param axis: the axis along which to index
    :param index: the indices of elements to gather

    Examples:

    .. testcode::

        import megengine.functional as F
        from megengine.core import tensor

        inp = tensor([
            [1, 2], [3, 4], [5, 6],
        ])
        index = tensor([[0, 2], [1, 0]])
        oup = F.gather(inp, 0, index)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[1 6]
         [3 2]]

    """
    input_shape = inp.imm_shape
    index_shape = index.imm_shape
    input_dims = len(input_shape)
    index_dims = len(index_shape)
    if input_dims != index_dims:
        raise ValueError(
            "The index tensor must have the same number of dimensions as the input tensor, "
            "but the input dims:{}, the index dims:{}".format(input_dims, index_dims)
        )
    if axis < 0 or axis >= input_dims:
        raise ValueError(
            "Index axis {} is out of bounds, should be in range [0, {})".format(
                axis, input_dims
            )
        )
    for i in range(input_dims):
        if i != axis and input_shape[i] != index_shape[i]:
            raise ValueError(
                "The input {} and index {} must have the same size apart from axis {}".format(
                    input_shape, index_shape, axis
                )
            )
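    # Emulate gather via advanced indexing: build one flat coordinate vector per
    # dimension (the one for ``axis`` comes from ``index``), pick those elements
    # from ``inp``, then restore ``index``'s shape.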
    idx = _get_idx(index, axis)
    return mgb.opr.advanced_indexing(inp)[idx].reshape(
        index.shape
    )  # pylint: disable=no-member

@wrap_io_tensor
def concat(
    inps: Iterable[Tensor],
    axis: int = 0,
    device: Optional[CompNode] = None,
    comp_graph: Optional[CompGraph] = None,
) -> Tensor:
    r"""
    Concatenate the input tensors along the given axis.

    :param inps: Input tensors to concatenate
    :param axis: the dimension over which the tensors are concatenated,
        defaults to 0
    :param device: the comp node on which the output resides, defaults to None
    :param comp_graph: the graph in which the output lives, defaults to None
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
        data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
        out = F.concat([data1, data2])
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[ 0.  1.  2.]
         [ 3.  4.  5.]
         [ 6.  7.  8.]
         [ 9. 10. 11.]]

    """
    # Output buffer not supported
    return mgb.opr.concat(
        *list(inps), axis=axis, comp_node=device, comp_graph=comp_graph
    )

@wrap_io_tensor
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
    r"""
    Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor.

    For each value in :attr:`source`, its output index is specified by its index
    in :attr:`source` for ``axis != dimension`` and by the corresponding value in
    :attr:`index` for ``axis = dimension``.

    For a 3-D tensor, :attr:`inp` is updated as::

        inp[index[i][j][k]][j][k] = source[i][j][k]  # if axis == 0
        inp[i][index[i][j][k]][k] = source[i][j][k]  # if axis == 1
        inp[i][j][index[i][j][k]] = source[i][j][k]  # if axis == 2

    :attr:`inp`, :attr:`index` and :attr:`source` should have the same number of dimensions.

    It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
    for all dimensions ``d``.

    Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.

    :param inp: the input tensor to be scattered into
    :param axis: the axis along which to index
    :param index: the indices of elements to scatter
    :param source: the source element(s) to scatter

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine.core import tensor

        inp = tensor(np.zeros(shape=(3, 5), dtype=np.float32))
        source = tensor([[0.9935, 0.9465, 0.2256, 0.8926, 0.4396], [0.7723, 0.0718, 0.5939, 0.357, 0.4576]])
        index = tensor([[0, 2, 0, 2, 1], [2, 0, 0, 1, 2]])
        oup = F.scatter(inp, 0, index, source)
        print(oup.numpy())

    Outputs:

    .. testoutput::
        :options: +SKIP

        [[0.9935 0.0718 0.5939 0.     0.    ]
         [0.     0.     0.     0.357  0.4396]
         [0.7723 0.9465 0.     0.8926 0.4576]]

    """
    input_shape = inp.imm_shape
    index_shape = index.imm_shape
    source_shape = source.imm_shape
    input_dims = len(input_shape)
    index_dims = len(index_shape)
    source_dims = len(source_shape)

    if input_dims != index_dims or input_dims != source_dims:
        raise ValueError("The input, source and index tensor must have the same number of dimensions")

    if axis < 0 or axis >= input_dims:
        raise ValueError(
            "Index axis {} is out of bounds, should be in range [0, {})".format(
                axis, input_dims
            )
        )

    for i in range(source_dims):
        if source_shape[i] > input_shape[i]:
            raise ValueError(
                "Each size of the source shape {} must be less than or equal to the input shape {}".format(
                    source_shape, input_shape
                )
            )

    for i in range(index_dims):
        if index_shape[i] != source_shape[i]:
            raise ValueError(
                "Each size of the index shape {} must be equal to the source shape {}".format(
                    index_shape, source_shape
                )
            )

    for i in range(index_dims):
        if i != axis and index_shape[i] > input_shape[i]:
            raise ValueError(
                "The index {} must be less than or equal to the input {} size apart from axis {}".format(
                    index_shape, input_shape, axis
                )
            )
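    # Emulate scatter via set_advanced_indexing: the coordinate vectors select the
    # destination positions in ``inp`` (the one for ``axis`` comes from ``index``),
    # and the flattened ``source`` supplies the values written there.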
    idx = _get_idx(index, axis)
    return mgb.opr.set_advanced_indexing(inp, source.flatten())[idx]

@wrap_io_tensor
def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
    r"""
    Select elements either from Tensor x or Tensor y, according to mask.

    .. math::

        \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i

    :param mask: a mask used for choosing x or y
    :param x: the first choice
    :param y: the second choice

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
        x = tensor(np.array([[1, np.inf], [np.nan, 4]],
            dtype=np.float32))
        y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
        out = F.where(mask, x, y)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[1. 6.]
         [7. 4.]]

    """
    v0, index0 = mgb.opr.cond_take(
        x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=1
    )
    v1, index1 = mgb.opr.cond_take(
        y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0
    )
    out = x.flatten()
    out = mgb.opr.set_advanced_indexing(out, v0)[index0]
    out = mgb.opr.set_advanced_indexing(out, v1)[index1]
    out = out.reshape(x.shape)
    return out

def shapeof(x: Tensor, axis=None):
    r"""
    Return the shape of the input tensor (optionally along a given ``axis``).
    """
    return x.shapeof(axis=axis)

@wrap_io_tensor
def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
    r"""
    Swap shapes and strides according to the given pattern

    :param inp: Input tensor
    :param pattern: a list of integers including 0, 1, ..., ``ndim``-1, and any number of ``'x'`` characters in dimensions where this tensor should be broadcasted. For examples:

        * (``'x'``) -> make a 0d (scalar) into a 1d vector
        * (0, 1) -> identity for 2d vectors
        * (1, 0) -> inverts the first and second dimensions
        * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
        * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
        * (2, 0, 1) -> AxBxC to CxAxB
        * (0, ``'x'``, 1) -> AxB to Ax1xB
        * (1, ``'x'``, 0) -> AxB to Bx1xA
        * (1,) -> removes dimension 0; it must be a broadcastable dimension (1xA to A)

    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
        out = F.dimshuffle(x, (1, 0))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[1 0]
         [1 0]]

    """
    return mgb.opr.dimshuffle(inp, pattern)

@wrap_io_tensor
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
    r"""
    Reshape a tensor to given target shape; total number of logical elements must
    remain unchanged

    :param inp: Input tensor
    :param target_shape: target shape, the components would be concatenated to form the
        target shape, and it can contain an element of -1 representing unspec_axis.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(12, dtype=np.int32))
        out = F.reshape(x, (3, 2, 2))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[[ 0  1]
          [ 2  3]]

         [[ 4  5]
          [ 6  7]]

         [[ 8  9]
          [10 11]]]

    """
    return mgb.opr.reshape(inp, target_shape)

@functools.wraps(dimshuffle)
def transpose(*args, **kwargs):
    r"""See :func:`dimshuffle`
    """
    return dimshuffle(*args, **kwargs)

@wrap_io_tensor
def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
    r"""
    Add dimension(s) before given axis/axes

    :param inp: Input tensor
    :param axis: Place(s) of new axes
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, 2])
        out = F.add_axis(x, (0, 2))
        print(out.shape)

    Outputs:

    .. testoutput::

        (1, 2, 1)

    """
    return mgb.opr.add_axis(inp, axis)

@wrap_io_tensor
def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
    r"""
    Remove dimension(s) of shape 1

    :param inp: Input tensor
    :param axis: Place(s) of axes to be removed
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
        out = F.remove_axis(x, (0, 0, 1))
        print(out.shape)

    Outputs:

    .. testoutput::

        (2,)

    """
    return mgb.opr.remove_axis(inp, axis)

def linspace(
    start: Union[int, float, Tensor],
    stop: Union[int, float, Tensor],
    num: int = 100,
    dtype=np.float32,
    device: Optional[CompNode] = None,
    comp_graph: Optional[CompGraph] = None,
) -> Tensor:
    r"""
    Return equally spaced numbers over a specified interval

    :param start: starting value of the sequence, should be a scalar
    :param stop: the last value of the sequence, should be a scalar
    :param num: number of values to generate
    :param dtype: result data type
    :return: The generated tensor

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F

        a = F.linspace(3, 10, 5)
        print(a.numpy())

    .. testoutput::

        [ 3.    4.75  6.5   8.25 10.  ]

    """
    if dtype is not np.float32:
        raise ValueError("linspace is only implemented for float32")

    device, comp_graph = _use_default_if_none(device, comp_graph)
    ret = Tensor(
        mgb.opr.linspace(start, stop, num, comp_node=device, comp_graph=comp_graph)
    )
    return ret.astype(dtype)

def zeros_like(inp: Tensor) -> Tensor:
    r"""
    Returns a zero tensor with the same shape as input tensor

    :param inp: input tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.zeros_like(inp)
        print(out.numpy())

    .. testoutput::

        [[0 0 0]
         [0 0 0]]

    """
    return zeros(inp.shapeof()).astype(inp.dtype)

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no need to choose between CPU and GPU builds. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
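To quickly check whether the GPU path is usable on a given machine, a minimal sketch such as the following can be run. It assumes a MegEngine release that exposes megengine.is_cuda_available(); the exact name of this check may differ in older versions.

import megengine as mge

# Report whether MegEngine detects a usable CUDA device on this machine.
if mge.is_cuda_available():
    print("CUDA device detected; GPU code can run.")
else:
    print("No CUDA device detected; code will run on CPU.")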
