You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tensor.py 41 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Iterable, Optional, Sequence, Tuple, Union
  10. import numpy as np
  11. from ..core._imperative_rt import CompNode
  12. from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion
  13. from ..core._wrap import as_device
  14. from ..core.ops import builtin
  15. from ..core.ops.builtin import Copy, Identity
  16. from ..core.ops.special import Const
  17. from ..core.tensor.array_method import _broadcast, _remove_axis
  18. from ..core.tensor.utils import astensor1d, convert_inputs, get_device
  19. from ..device import get_default_device
  20. from ..tensor import Tensor
  21. from .elemwise import ceil
# Names exported by ``from <this module> import *``.
__all__ = [
    "arange",
    "broadcast_to",
    "concat",
    "cond_take",
    "cumsum",
    "diag",
    "expand_dims",
    "eye",
    "flatten",
    "full",
    "full_like",
    "gather",
    "linspace",
    "ones",
    "ones_like",
    "repeat",
    "reshape",
    "roll",
    "split",
    "squeeze",
    "stack",
    "scatter",
    "tile",
    "copy",
    "transpose",
    "where",
    "zeros",
    "zeros_like",
]
  52. def diag(inp, k=0) -> Tensor:
  53. r"""If ``inp`` is a 1D tensor, then returns a 2D tensor with the elements of ``inp`` as the diagonal.
  54. If ``inp`` is a 2D tensor, then returns a 1D tensor with the diagonal elements of ``inp``.
  55. Args:
  56. inp: input tensor.
  57. k: diagonal in consider. Use :math:`k=0` for the main diagonal, :math:`k>0` for diagonals above the
  58. main diagonal, and :math:`k<0` for diagonals below the main diagonal. Default: 0.
  59. Returns:
  60. the extracted diagonal or constructed diagonal array.
  61. Examples:
  62. >>> inp = F.arange(6, dtype='int32').reshape(2,3)
  63. >>> out = F.diag(inp, k=1)
  64. >>> out
  65. Tensor([1 5], dtype=int32, device=xpux:0)
  66. >>> F.diag(out)
  67. Tensor([[1 0]
  68. [0 5]], dtype=int32, device=xpux:0)
  69. """
  70. op = builtin.Diag(k=k)
  71. (result,) = apply(op, inp)
  72. return result
  73. def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
  74. r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere.
  75. Args:
  76. shape: a list, tuple or integer defining the shape of the output tensor.
  77. dtype: the desired data type of the output tensor. Default: ``float32``.
  78. device: the desired device of the output tensor. Default: if ``None``,
  79. use the default device (see :func:`~.megengine.get_default_device`).
  80. Returns:
  81. eye matrix.
  82. Examples:
  83. .. testcode::
  84. import numpy as np
  85. import megengine.functional as F
  86. out = F.eye(4, 6, dtype=np.float32)
  87. print(out.numpy())
  88. Outputs:
  89. .. testoutput::
  90. [[1. 0. 0. 0. 0. 0.]
  91. [0. 1. 0. 0. 0. 0.]
  92. [0. 0. 1. 0. 0. 0.]
  93. [0. 0. 0. 1. 0. 0.]]
  94. """
  95. if M is not None:
  96. if isinstance(N, Tensor) or isinstance(M, Tensor):
  97. shape = astensor1d((N, M))
  98. else:
  99. shape = Tensor([N, M], dtype="int32", device=device)
  100. elif isinstance(N, Tensor):
  101. shape = N
  102. else:
  103. shape = Tensor(N, dtype="int32", device=device)
  104. op = builtin.Eye(k=0, dtype=dtype, comp_node=device)
  105. (result,) = apply(op, shape)
  106. return result
  107. def full(
  108. shape: Union[int, tuple, list],
  109. value: Union[bool, int, float, Tensor],
  110. dtype=None,
  111. device=None,
  112. ) -> Tensor:
  113. r"""Creates a tensor of shape ``shape`` filled with ``value``.
  114. Args:
  115. shape: output tensor shape.
  116. value: fill value.
  117. dtype: output tensor data type. If ``dtype`` is ``None``, the output tensor
  118. data type must be inferred from ``value``. If the value is an ``int``,
  119. the output tensor data type must be the default integer data type. If the
  120. value is a ``float``, the output tensor data type must be the default
  121. floating-point data type. If the value is a ``bool``, the output tensor
  122. must have boolean data type. Default: ``None``.
  123. device: device on which to place the created tensor. Default: ``None``.
  124. Returns:
  125. a tensor where every element is equal to ``value``.
  126. Examples:
  127. .. testcode::
  128. import numpy as np
  129. import megengine.functional as F
  130. out = F.full([2,3], 1.5)
  131. print(out.numpy())
  132. Outputs:
  133. .. testoutput::
  134. [[1.5 1.5 1.5]
  135. [1.5 1.5 1.5]]
  136. """
  137. if isinstance(shape, int):
  138. shape = (shape,)
  139. if device is None:
  140. device = get_default_device()
  141. (x,) = Const(value, dtype=dtype, device=device)()
  142. if type(shape) in (list, tuple) and len(shape) == 0:
  143. return x
  144. return broadcast_to(x, shape)
  145. def ones(
  146. shape: Union[int, Tuple[int, ...]],
  147. *,
  148. dtype="float32",
  149. device: Optional[CompNode] = None
  150. ) -> Tensor:
  151. r"""Returns a new tensor having a specified shape and filled with ones.
  152. Args:
  153. shape (int or sequence of ints): the shape of the output tensor.
  154. Keyword args:
  155. dtype (:attr:`.Tensor.dtype`): output tensor data type. Default: ``float32``.
  156. device (:attr:`.Tensor.device`): device on which to place the created tensor. Default: ``None``.
  157. Returns:
  158. a tensor containing ones.
  159. Examples:
  160. .. testcode::
  161. import megengine.functional as F
  162. out = F.ones(5)
  163. print(out.numpy())
  164. out = F.ones((5, ), dtype='int32')
  165. print(out.numpy())
  166. out = F.ones((2, 2))
  167. print(out.numpy())
  168. out = F.ones([2, 1])
  169. print(out.numpy())
  170. Outputs:
  171. .. testoutput::
  172. [1. 1. 1. 1. 1.]
  173. [1 1 1 1 1]
  174. [[1. 1.]
  175. [1. 1.]]
  176. [[1.]
  177. [1.]]
  178. """
  179. return full(shape, 1.0, dtype=dtype, device=device)
  180. def zeros(
  181. shape: Union[int, Tuple[int, ...]],
  182. *,
  183. dtype="float32",
  184. device: Optional[CompNode] = None
  185. ) -> Tensor:
  186. r"""Returns a new tensor having a specified shape and filled with zeros.
  187. Args:
  188. shape (int or sequence of ints): the shape of the output tensor.
  189. Keyword args:
  190. dtype (:attr:`.Tensor.dtype`): output tensor data type. Default: ``float32``.
  191. device (:attr:`.Tensor.device`): device on which to place the created tensor. Default: ``None``.
  192. Returns:
  193. a tensor containing zeros.
  194. Examples:
  195. >>> F.zeros((2, 1))
  196. Tensor([[0.]
  197. [0.]], device=xpux:0)
  198. """
  199. return full(shape, 0.0, dtype=dtype, device=device)
  200. def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
  201. r"""Returns a tensor filled with zeros with the same shape and data type as input tensor.
  202. Args:
  203. inp (Tensor): input tensor.
  204. Return:
  205. a tensor containing zeros.
  206. Examples:
  207. >>> input = F.arange(9, dtype='int32').reshape(3,3)
  208. >>> F.zeros_like(input)
  209. Tensor([[0 0 0]
  210. [0 0 0]
  211. [0 0 0]], dtype=int32, device=xpux:0)
  212. """
  213. return full_like(inp, 0.0)
  214. def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
  215. r"""Returns a tensor filled with ones with the same shape and data type as input tensor.
  216. Args:
  217. inp (Tensor): input tensor.
  218. Return:
  219. a tensor containing ones.
  220. Examples:
  221. >>> input = F.arange(6, dtype='int32').reshape(2,3)
  222. >>> F.ones_like(input)
  223. Tensor([[1 1 1]
  224. [1 1 1]], dtype=int32, device=xpux:0)
  225. """
  226. return full_like(inp, 1.0)
  227. def full_like(
  228. inp: Union[Tensor, SymbolVar], value: Union[int, float]
  229. ) -> Union[Tensor, SymbolVar]:
  230. r"""Returns a tensor filled with given value with the same shape as input tensor.
  231. Args:
  232. inp: input tensor.
  233. value: target value.
  234. Return:
  235. output tensor.
  236. Examples:
  237. .. testcode::
  238. import numpy as np
  239. from megengine import tensor
  240. import megengine.functional as F
  241. inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  242. out = F.full_like(inp, 2)
  243. print(out.numpy())
  244. Outputs:
  245. .. testoutput::
  246. [[2 2 2]
  247. [2 2 2]]
  248. """
  249. (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
  250. if inp.ndim == 0:
  251. return x
  252. return broadcast_to(x, inp.shape)
  253. def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
  254. r"""Broadcasts a tensor to given shape.
  255. Args:
  256. inp: input tensor.
  257. shape: target shape.
  258. Returns:
  259. output tensor.
  260. Examples:
  261. .. testcode::
  262. import numpy as np
  263. from megengine import tensor
  264. import megengine.functional as F
  265. data = tensor(np.arange(0, 3, dtype=np.float32).reshape(3))
  266. out = F.broadcast_to(data, (2, 3))
  267. print(out.numpy())
  268. Outputs:
  269. .. testoutput::
  270. [[0. 1. 2.]
  271. [0. 1. 2.]]
  272. """
  273. return _broadcast(inp, shape)
  274. def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
  275. r"""Concat some tensors
  276. Args:
  277. inps: input tensors to concat.
  278. axis: over which dimension the tensors are concatenated. Default: 0
  279. device: which device output will be. Default: None
  280. Returns:
  281. output tensor.
  282. Examples:
  283. .. testcode::
  284. import numpy as np
  285. from megengine import tensor
  286. import megengine.functional as F
  287. data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
  288. data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
  289. out = F.concat([data1, data2])
  290. print(out.numpy())
  291. Outputs:
  292. .. testoutput::
  293. [[ 0. 1. 2.]
  294. [ 3. 4. 5.]
  295. [ 6. 7. 8.]
  296. [ 9. 10. 11.]]
  297. """
  298. if len(inps) == 1:
  299. return inps[0]
  300. # FIXME: remove this convert_inputs
  301. inps = convert_inputs(*inps, device=device)
  302. if device is None:
  303. device = get_device(inps)
  304. device = as_device(device)
  305. (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
  306. return result
  307. def stack(inps, axis=0, device=None):
  308. r"""Concats a sequence of tensors along a new axis.
  309. The input tensors must have the same shape.
  310. Args:
  311. inps: input tensors.
  312. axis: which axis will be concatenated.
  313. device: the device output will be. Default: None
  314. Returns:
  315. output concatenated tensor.
  316. Examples:
  317. .. testcode::
  318. import numpy as np
  319. from megengine import tensor
  320. import megengine.functional as F
  321. x1 = tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
  322. x2 = tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
  323. out = F.stack([x1, x2], axis=0)
  324. print(out.numpy())
  325. Outputs:
  326. .. testoutput::
  327. [[0. 1. 2.]
  328. [6. 7. 8.]]
  329. """
  330. if len(inps) > 0 and not isinstance(inps[0].shape, inps[0].__class__):
  331. shapes = {arr.shape for arr in inps}
  332. if len(shapes) != 1:
  333. raise ValueError("All input tensors must have the same shape")
  334. inps = [expand_dims(inp, axis=axis) for inp in inps]
  335. return concat(inps, axis=axis, device=device)
  336. def split(inp, nsplits_or_sections, axis=0):
  337. r"""Splits the input tensor into several smaller tensors.
  338. When nsplits_or_sections is int, the last tensor may be smaller than others.
  339. Args:
  340. inp: input tensor.
  341. nsplits_or_sections: number of sub tensors or sections information list.
  342. axis: which axis will be splited.
  343. Returns:
  344. output tensor list.
  345. Examples:
  346. .. testcode::
  347. import os
  348. import numpy as np
  349. from megengine import tensor
  350. import megengine.functional as F
  351. x = tensor(np.random.random((10, 20)), dtype=np.float32)
  352. y = F.split(x, 3)
  353. z = F.split(x, [6, 17], axis=1)
  354. print([i.numpy().shape for i in y])
  355. print([i.numpy().shape for i in z])
  356. Outputs:
  357. .. testoutput::
  358. [(4, 20), (3, 20), (3, 20)]
  359. [(10, 6), (10, 11), (10, 3)]
  360. """
  361. ndim = len(inp.shape)
  362. if axis >= ndim:
  363. raise ValueError("Invalid axis {}".format(axis))
  364. Ntotal = inp.shape[axis]
  365. if isinstance(nsplits_or_sections, Sequence):
  366. Nsections = len(nsplits_or_sections) + 1
  367. is_array = True
  368. else:
  369. Nsections = int(nsplits_or_sections)
  370. is_array = False
  371. if is_array:
  372. partitions = []
  373. div_points = [0] + list(nsplits_or_sections) + [Ntotal]
  374. for i in range(1, len(div_points)):
  375. if div_points[i - 1] > div_points[i]:
  376. raise ValueError(
  377. "Invalid nsplits_or_secions: {}".format(nsplits_or_sections)
  378. )
  379. partitions.append(div_points[i] - div_points[i - 1])
  380. else: # scalar
  381. if Nsections <= 0:
  382. raise ValueError("Number sections must be larger than 0")
  383. if Nsections > Ntotal:
  384. raise ValueError(
  385. "The size {} at dim {} cannot be split into {} sections".format(
  386. Ntotal, axis, Nsections
  387. )
  388. )
  389. partitions = []
  390. for i in range(Nsections):
  391. section_size = (Ntotal + Nsections - i - 1) // Nsections
  392. partitions.append(section_size)
  393. partitions = [
  394. part
  395. if isinstance(part, (SymbolVar, Tensor))
  396. else Const(part, dtype="int32", device=inp.device)(inp)[0]
  397. for part in partitions
  398. ]
  399. op = builtin.Split(axis=axis)
  400. return apply(op, inp, *partitions)
  401. def _get_idx(index, axis):
  402. index_dims = len(index.shape)
  403. idx = []
  404. for i in range(index_dims):
  405. if i != axis:
  406. shape = [1] * index_dims
  407. shape[i] = index.shape[i]
  408. arange = linspace(
  409. 0, index.shape[i] - 1, index.shape[i], device=index.device,
  410. )
  411. arange = (
  412. broadcast_to(arange.reshape(*shape), index.shape)
  413. .reshape(-1)
  414. .astype(np.int32)
  415. )
  416. idx.append(arange)
  417. else:
  418. idx.append(index.reshape(-1))
  419. return tuple(idx)
  420. def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
  421. # TODO: rewrite doc
  422. r"""
  423. Gathers data from input tensor on axis using index.
  424. For a 3-D tensor, the output is specified by:
  425. .. code-block::
  426. out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
  427. out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
  428. out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2
  429. if input tensor is a n-dimensional tensor with size
  430. :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
  431. then index must be a n-dimensional tensor with size
  432. :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
  433. output will have the same size as index.
  434. Args:
  435. inp: input tensor.
  436. axis: along which axis to index.
  437. index: indices of elements to gather.
  438. Return:
  439. output tensor.
  440. Examples:
  441. .. testcode::
  442. import megengine.functional as F
  443. from megengine import tensor
  444. inp = tensor([
  445. [1,2], [3,4], [5,6],
  446. ])
  447. index = tensor([[0,2], [1,0]])
  448. oup = F.gather(inp, 0, index)
  449. print(oup.numpy())
  450. Outputs:
  451. .. testoutput::
  452. [[1 6]
  453. [3 2]]
  454. """
  455. input_shape = inp.shape
  456. index_shape = index.shape
  457. input_dims = len(input_shape)
  458. index_dims = len(index_shape)
  459. if input_dims != index_dims:
  460. raise ValueError(
  461. "The index tensor must have same dimensions as input tensor, "
  462. "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
  463. )
  464. if axis < 0 or axis >= input_dims:
  465. raise ValueError(
  466. "Index axis {} is output of bounds, should in range [0 {})".format(
  467. axis, input_dims
  468. )
  469. )
  470. for i in range(input_dims):
  471. if i != axis and input_shape[i] != index_shape[i]:
  472. raise ValueError(
  473. "The input {} and index {} must have the same size apart from axis {}".format(
  474. input_shape, index_shape, axis
  475. )
  476. )
  477. idx = _get_idx(index, axis)
  478. return inp[idx].reshape(index.shape) # pylint: disable=no-member
  479. def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
  480. # TODO: rewrite doc
  481. r"""
  482. Writes all values from the tensor source into input tensor
  483. at the indices specified in the index tensor.
  484. For each value in source, its output index is specified by its index
  485. in source for ``axis != dimension`` and by the corresponding value in
  486. index for ``axis = dimension``.
  487. For a 3-D tensor, input tensor is updated as:
  488. .. code-block::
  489. inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0
  490. inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1
  491. inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2
  492. ``inp``, ``index`` and ``source`` should have same number of dimensions.
  493. It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
  494. for all dimensions ``d``.
  495. Moreover, the values of index must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.
  496. Note:
  497. Please notice that, due to performance issues, the result is uncertain on the GPU device
  498. if scattering different positions from source to the same destination position
  499. regard to index tensor.
  500. Check the following examples, the oup[0][2] is maybe
  501. from source[0][2] which value is 0.2256 or source[1][2] which value is 0.5339
  502. if set the index[1][2] from 1 to 0.
  503. Args:
  504. inp: inp tensor which to be scattered.
  505. axis: axis along which to index.
  506. index: indices of elements to scatter.
  507. source: source element(s) to scatter.
  508. Return:
  509. output tensor.
  510. Examples:
  511. .. testcode::
  512. import numpy as np
  513. import megengine.functional as F
  514. from megengine import tensor
  515. inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
  516. source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
  517. index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
  518. oup = F.scatter(inp, 0, index,source)
  519. print(oup.numpy())
  520. Outputs:
  521. .. testoutput::
  522. [[0.9935 0.0718 0.2256 0. 0. ]
  523. [0. 0. 0.5939 0.357 0.4396]
  524. [0.7723 0.9465 0. 0.8926 0.4576]]
  525. """
  526. input_shape = inp.shape
  527. index_shape = index.shape
  528. source_shape = source.shape
  529. input_dims = len(input_shape)
  530. index_dims = len(index_shape)
  531. source_dims = len(source_shape)
  532. if input_dims != index_dims or input_dims != source_dims:
  533. raise ValueError("The input, source and index tensor must have same dimensions")
  534. if axis < 0 or axis >= input_dims:
  535. raise ValueError(
  536. "Index axis {} is output of bounds, should in range [0 {})".format(
  537. axis, input_dims
  538. )
  539. )
  540. for i in range(source_dims):
  541. if source_shape[i] > input_shape[i]:
  542. raise ValueError(
  543. "The each shape size for source {} must be less than or equal to input {} ".format(
  544. source_shape, input_shape
  545. )
  546. )
  547. for i in range(index_dims):
  548. if index_shape[i] != source_shape[i]:
  549. raise ValueError(
  550. "The each shape size for index {} must be equal to source {} ".format(
  551. index_shape, source_shape
  552. )
  553. )
  554. for i in range(index_dims):
  555. if i != axis and index_shape[i] > input_shape[i]:
  556. raise ValueError(
  557. "The index {} must be less than or equal to input {} size apart from axis {}".format(
  558. index_shape, input_shape, axis
  559. )
  560. )
  561. idx = _get_idx(index, axis)
  562. inp[idx] = source.flatten()
  563. return inp
  564. def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
  565. r"""Selects elements either from Tensor x or Tensor y, according to mask.
  566. .. math::
  567. \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i
  568. Args:
  569. mask: a mask used for choosing ``x`` or ``y``.
  570. x: first choice.
  571. y: second choice.
  572. Returns:
  573. output tensor.
  574. Examples:
  575. .. testcode::
  576. import numpy as np
  577. from megengine import tensor
  578. import megengine.functional as F
  579. mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool))
  580. x = tensor(np.array([[1, np.inf], [np.nan, 4]],
  581. dtype=np.float32))
  582. y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
  583. out = F.where(mask, x, y)
  584. print(out.numpy())
  585. Outputs:
  586. .. testoutput::
  587. [[1. 6.]
  588. [7. 4.]]
  589. """
  590. if not isinstance(x, Tensor):
  591. raise TypeError("input x must be a tensor")
  592. if not isinstance(y, Tensor):
  593. raise TypeError("input y must be a tensor")
  594. if not isinstance(mask, Tensor):
  595. raise TypeError("mask must be a tensor")
  596. if mask.dtype != np.bool_:
  597. raise ValueError("mask must be bool")
  598. if x.device != mask.device:
  599. raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
  600. dtype = dtype_promotion(x, y)
  601. if x.dtype != dtype:
  602. x = x.astype(dtype)
  603. if y.dtype != dtype:
  604. y = y.astype(dtype)
  605. v0, index0 = cond_take(mask, x)
  606. v1, index1 = cond_take(~mask, y)
  607. out = concat([v0, v1])
  608. out[index0] = v0
  609. out[index1] = v1
  610. out = out.reshape(x.shape)
  611. return out
  612. def cond_take(mask: Tensor, x: Tensor) -> Tensor:
  613. r"""Takes elements from data if specific condition is satisfied on mask.
  614. This operator has two outputs: the first is the elements taken,
  615. and the second is the indices corresponding to those elements;
  616. they are both 1-dimensional. High-dimension input would first be flattened.
  617. Args:
  618. mask: condition param; must be the same shape with data.
  619. x: input tensor from which to take elements.
  620. Examples:
  621. .. testcode::
  622. import numpy as np
  623. from megengine import tensor
  624. import megengine.functional as F
  625. mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
  626. x = tensor(np.array([[1, np.inf], [np.nan, 4]],
  627. dtype=np.float32))
  628. v, index = F.cond_take(mask, x)
  629. print(v.numpy(), index.numpy())
  630. Outputs:
  631. .. testoutput::
  632. [1. 4.] [0 3]
  633. """
  634. if not isinstance(x, (Tensor, SymbolVar)):
  635. raise TypeError("input must be a tensor")
  636. if not isinstance(mask, (Tensor, SymbolVar)):
  637. raise TypeError("mask must be a tensor")
  638. if mask.dtype != np.bool_:
  639. raise ValueError("mask must be bool")
  640. if x.device != mask.device:
  641. raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
  642. op = builtin.CondTake()
  643. v, index = apply(op, x, mask)
  644. return v, index
  645. def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
  646. r"""Swaps shapes and strides according to given pattern.
  647. Args:
  648. inp: input tensor.
  649. pattern: a list of integers including 0, 1, ... , ``ndim``-1,
  650. and any number of ``'x'`` char in dimensions where this tensor should be broadcasted.
  651. For examples:
  652. * (``'x'``) -> make a 0d (scalar) into a 1d vector
  653. * (0, 1) -> identity for 2d vectors
  654. * (1, 0) -> inverts the first and second dimensions
  655. * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
  656. * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
  657. * (2, 0, 1) -> AxBxC to CxAxB
  658. * (0, ``'x'``, 1) -> AxB to Ax1xB
  659. * (1, ``'x'``, 0) -> AxB to Bx1xA
  660. * (1,) -> this removes dimensions 0. It must be a broadcastable dimension (1xA to A)
  661. Returns:
  662. output tensor.
  663. Examples:
  664. .. testcode::
  665. import numpy as np
  666. from megengine import tensor
  667. import megengine.functional as F
  668. x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
  669. out = F.transpose(x, (1, 0))
  670. print(out.numpy())
  671. Outputs:
  672. .. testoutput::
  673. [[1 0]
  674. [1 0]]
  675. """
  676. return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern))
  677. def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
  678. r"""Reshapes a tensor without changing its data.
  679. Args:
  680. inp: input tensor to reshape.
  681. target_shape: target shape compatible with the original shape. One shape dimension is allowed
  682. to be `-1` . When a shape dimension is `-1` , the corresponding output tensor shape dimension
  683. must be inferred from the length of the tensor and the remaining dimensions.
  684. Returns:
  685. an output tensor having the same data type, elements, and underlying element order as `inp` .
  686. Examples:
  687. >>> x = F.arange(12)
  688. >>> x
  689. Tensor([ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11.], device=xpux:0)
  690. >>> F.reshape(x, (3, 4))
  691. Tensor([[ 0. 1. 2. 3.]
  692. [ 4. 5. 6. 7.]
  693. [ 8. 9. 10. 11.]], device=xpux:0)
  694. >>> F.reshape(x, (2, -1))
  695. Tensor([[ 0. 1. 2. 3. 4. 5.]
  696. [ 6. 7. 8. 9. 10. 11.]], device=xpux:0)
  697. """
  698. return inp.reshape(target_shape)
  699. def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
  700. r"""Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``.
  701. Args:
  702. inp: input tensor.
  703. start_axis: start dimension that the sub-tensor to be flattened. Default: 0
  704. end_axis: end dimension that the sub-tensor to be flattened. Default: -1
  705. Returns:
  706. output tensor.
  707. Examples:
  708. .. testcode::
  709. import numpy as np
  710. from megengine import tensor
  711. import megengine.functional as F
  712. inp_shape = (2, 2, 3, 3)
  713. x = tensor(
  714. np.arange(36, dtype=np.int32).reshape(inp_shape),
  715. )
  716. out = F.flatten(x, 2)
  717. print(x.numpy().shape)
  718. print(out.numpy().shape)
  719. Outputs:
  720. .. testoutput::
  721. (2, 2, 3, 3)
  722. (2, 2, 9)
  723. """
  724. target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
  725. if end_axis != -1:
  726. target_shape += (*inp.shape[end_axis + 1 :],)
  727. return inp.reshape(*target_shape)
  728. def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
  729. r"""Adds dimension before given axis.
  730. Args:
  731. inp: input tensor.
  732. axis: place of new axes.
  733. Returns:
  734. output tensor.
  735. Examples:
  736. .. testcode::
  737. import numpy as np
  738. from megengine import tensor
  739. import megengine.functional as F
  740. x = tensor([1, 2])
  741. out = F.expand_dims(x, 0)
  742. print(out.numpy().shape)
  743. Outputs:
  744. .. testoutput::
  745. (1, 2)
  746. """
  747. def get_axes():
  748. try:
  749. return [int(axis)]
  750. except (TypeError, ValueError):
  751. pass
  752. return list(map(int, axis))
  753. axis = get_axes()
  754. try:
  755. ndim = inp.ndim + len(axis)
  756. axis = sorted(i + ndim if i < 0 else i for i in axis)
  757. except ValueError:
  758. if any([ind < 0 for ind in axis]):
  759. raise IndexError(
  760. "Does not support negative index when tensor's ndim is unknown"
  761. )
  762. axis = sorted(axis)
  763. assert axis, "axis could not be empty"
  764. op = builtin.AddAxis(axis=axis)
  765. (result,) = apply(op, inp)
  766. return result
  767. def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
  768. r"""Removes dimension of shape 1.
  769. Args:
  770. inp: input tensor.
  771. axis: place of axis to be removed.
  772. Returns:
  773. output tensor.
  774. Examples:
  775. .. testcode::
  776. import numpy as np
  777. from megengine import tensor
  778. import megengine.functional as F
  779. x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
  780. out = F.squeeze(x, 3)
  781. print(out.numpy().shape)
  782. Outputs:
  783. .. testoutput::
  784. (1, 1, 2)
  785. """
  786. return _remove_axis(inp, axis)
  787. def linspace(
  788. start: Union[int, float, Tensor],
  789. stop: Union[int, float, Tensor],
  790. num: Union[int, Tensor],
  791. dtype="float32",
  792. device: Optional[CompNode] = None,
  793. ) -> Tensor:
  794. r"""Returns equally spaced numbers over a specified interval.
  795. Args:
  796. start: starting value of the squence, shoule be scalar.
  797. stop: last value of the squence, shoule be scalar.
  798. num: number of values to generate.
  799. dtype: result data type.
  800. Returns:
  801. generated tensor.
  802. Examples:
  803. .. testcode::
  804. import numpy as np
  805. import megengine.functional as F
  806. a = F.linspace(3, 10, 5)
  807. print(a.numpy())
  808. Outputs:
  809. .. testoutput::
  810. [ 3. 4.75 6.5 8.25 10. ]
  811. """
  812. for item in (start, stop, num):
  813. cur_device = getattr(item, "device", None)
  814. if device is None:
  815. device = cur_device
  816. else:
  817. if not (cur_device is None or device == cur_device):
  818. raise ("ambiguous device for linspace opr")
  819. is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num])
  820. if any(is_symbolvar) and not all(is_symbolvar):
  821. raise TypeError("start, stop and num should all be VarNode or none of them")
  822. if not isinstance(start, (Tensor, SymbolVar)):
  823. start = Tensor(start, device=device)
  824. if not isinstance(stop, (Tensor, SymbolVar)):
  825. stop = Tensor(stop, device=device)
  826. if not isinstance(num, (Tensor, SymbolVar)):
  827. num = Tensor(num, device=device)
  828. op = builtin.Linspace(comp_node=device)
  829. (result,) = apply(op, start, stop, num)
  830. if np.dtype(dtype) != np.float32:
  831. return result.astype(dtype)
  832. return result
  833. def arange(
  834. start: Union[int, float, Tensor] = 0,
  835. stop: Optional[Union[int, float, Tensor]] = None,
  836. step: Union[int, float, Tensor] = 1,
  837. dtype="float32",
  838. device: Optional[CompNode] = None,
  839. ) -> Tensor:
  840. r"""Returns evenly spaced values within the half-open interval ``[start, stop)`` as a one-dimensional tensor.
  841. Note:
  842. This function cannot guarantee that the interval does not include the stop value in those cases
  843. where step is not an integer and floating-point rounding errors affect the length of the output tensor.
  844. Args:
  845. start: if ``stop`` is specified, the start of interval (inclusive); otherwise,
  846. the end of the interval (exclusive). If ``stop`` is not specified, the default starting value is ``0``.
  847. stop: the end of the interval. Default: ``None``.
  848. step: the distance between two adjacent elements ( ``out[i+1] - out[i]`` ). Must not be 0 ;
  849. may be negative, this results i an empty tensor if stop >= start . Default: 1 .
  850. Keyword args:
  851. dtype( :attr:`.Tensor.dtype` ): output tensor data type. Default: ``float32``.
  852. device( :attr:`.Tensor.device` ): device on which to place the created tensor. Default: ``None``.
  853. Returns:
  854. A one-dimensional tensor containing evenly spaced values.
  855. The length of the output tensor must be ``ceil((stop-start)/step)``
  856. if ``stop - start`` and ``step`` have the same sign, and length 0 otherwise.
  857. Examples:
  858. >>> F.arange(5)
  859. Tensor([0. 1. 2. 3. 4.], device=xpux:0)
  860. >>> F.arange(1, 4)
  861. Tensor([1. 2. 3.], device=xpux:0)
  862. """
  863. if stop is None:
  864. start, stop = 0, start
  865. start = Tensor(start, dtype="float32")
  866. stop = Tensor(stop, dtype="float32")
  867. step = Tensor(step, dtype="float32")
  868. num = ceil((stop - start) / step)
  869. stop = start + step * (num - 1)
  870. result = linspace(start, stop, num, device=device)
  871. if np.dtype(dtype) != np.float32:
  872. return result.astype(dtype)
  873. return result
  874. def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
  875. r"""Repeat elements of an array.
  876. Args:
  877. inp: input tensor.
  878. repeats: the number of repetitions for each element.
  879. axis: the axis along which to repeat values. By default, use the
  880. flattened input array, and return a flat output array.
  881. Returns:
  882. output tensor.
  883. Examples:
  884. .. testcode::
  885. import numpy as np
  886. import megengine.functional as F
  887. from megengine import tensor
  888. x = tensor([[1, 2], [3, 4]], np.int32)
  889. y = F.repeat(x, 2, axis=0)
  890. print(y.numpy())
  891. Outputs:
  892. .. testoutput::
  893. [[1 2]
  894. [1 2]
  895. [3 4]
  896. [3 4]]
  897. """
  898. if axis is None:
  899. inp = inp.reshape(-1) # flatten
  900. axis = 0
  901. shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
  902. # assume inp.ndim is not changed during trace
  903. max_axis = len(shape) - 1
  904. assert axis >= 0 and axis <= max_axis
  905. assert repeats >= 1
  906. base_shape, bcast_shape, target_shape = [], [], []
  907. if axis != 0:
  908. target_shape.append(shape[:axis])
  909. base_shape.extend([shape[: axis + 1], [1,]])
  910. bcast_shape.extend([shape[: axis + 1], [repeats,]])
  911. target_shape.extend(
  912. [shape[axis] * repeats,]
  913. )
  914. if axis + 1 <= max_axis:
  915. base_shape.append(shape[axis + 1 :])
  916. bcast_shape.append(shape[axis + 1 :])
  917. target_shape.append(shape[axis + 1 :])
  918. out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
  919. concat(target_shape)
  920. )
  921. return out
  922. def _tile_one_dim(inp, rep, axis):
  923. shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
  924. # assume inp.ndim is not changed during trace
  925. max_axis = len(shape) - 1
  926. base_shape, bcast_shape, target_shape = [], [], []
  927. if axis != 0:
  928. base_shape.append(shape[:axis])
  929. bcast_shape.append(shape[:axis])
  930. target_shape.append(shape[:axis])
  931. base_shape.extend([[1,], shape[axis:]])
  932. bcast_shape.extend([rep, shape[axis:]])
  933. target_shape.append(shape[axis] * rep)
  934. if axis + 1 <= max_axis:
  935. target_shape.append(shape[axis + 1 :])
  936. out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
  937. concat(target_shape)
  938. )
  939. return out
  940. def tile(inp: Tensor, reps: Iterable[int]):
  941. r"""Construct an array by repeating ``inp`` the number of times given by ``reps``. If reps has length d,
  942. the result will have dimension of ``max(d, inp.ndim)``. It is required that ``d >= inp.dim``. If ``inp.ndim < d``,
  943. ``inp`` is promoted to be ``d``-dimensional by prepending new axis.
  944. Args:
  945. inp: input tensor.
  946. reps: The number of repetitions of inp along each axis.
  947. Returns:
  948. output tensor.
  949. Examples:
  950. .. testcode::
  951. import numpy as np
  952. import megengine.functional as F
  953. from megengine import tensor
  954. x = tensor([[1, 2], [3, 4]], np.int32)
  955. y = F.tile(x, (2,1))
  956. print(y.numpy())
  957. Outputs:
  958. .. testoutput::
  959. [[1 2]
  960. [3 4]
  961. [1 2]
  962. [3 4]]
  963. """
  964. shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
  965. reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
  966. l_shape = len(shape)
  967. l_reps = len(reps)
  968. assert (
  969. l_reps >= l_shape
  970. ), "Number of dimensions of tiled dims can not be smaller than number of dimensions of tensor"
  971. for i in range(l_shape):
  972. rep = reps[i + (l_reps - l_shape)]
  973. inp = _tile_one_dim(inp, rep, i)
  974. if l_reps > l_shape:
  975. shape = inp.shape
  976. extra = reps[:-l_shape]
  977. extra_ones = ones_like(extra)
  978. base_shape = concat([extra_ones, shape])
  979. bcast_shape = concat([extra, shape])
  980. target_shape = concat([extra, shape])
  981. inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
  982. return inp
  983. def copy(inp, device=None):
  984. r"""Copies tensor to another device.
  985. Args:
  986. inp: input tensor.
  987. device: destination device.
  988. Examples:
  989. .. testcode::
  990. import numpy as np
  991. import platform
  992. from megengine import tensor
  993. from megengine.device import get_device_count
  994. import megengine.functional as F
  995. x = tensor([1, 2, 3], np.int32)
  996. if 1 == get_device_count("gpu"):
  997. y = F.copy(x, "cpu1")
  998. print(y.numpy())
  999. else:
  1000. y = F.copy(x, "xpu1")
  1001. print(y.numpy())
  1002. Outputs:
  1003. .. testoutput::
  1004. [1 2 3]
  1005. """
  1006. if device is None:
  1007. return apply(Identity(), inp)[0]
  1008. return apply(Copy(comp_node=as_device(device).to_c()), inp)[0]
  1009. def roll(
  1010. inp: Tensor,
  1011. shift: Union[int, Iterable[int]],
  1012. axis: Optional[Union[int, Iterable[int]]] = None,
  1013. ):
  1014. r"""Roll the tensor along the given axis(or axes). Elements that are shifted
  1015. beyond the last position are re-introduced at the first position.
  1016. Args:
  1017. inp: input tensor.
  1018. shift: the number of places by which the elements of the tensor are
  1019. shifted. If shift is a tuple, axis must be a tuple of the same size,
  1020. and each axis will be rolled by the corresponding shift value.
  1021. axis: axis along which to roll. If axis is not specified, the tensor
  1022. will be flattened before rolling and then restored to the original shape.
  1023. Duplicate axes is allowed if it is a tuple. Default: None.
  1024. Examples:
  1025. .. testcode::
  1026. import numpy as np
  1027. from megengine import tensor
  1028. import megengine.functional as F
  1029. x = tensor([[1,2],[3,4],[5,6]], np.int32)
  1030. y = F.roll(x, 1, 0)
  1031. print(y.numpy())
  1032. Outputs:
  1033. .. testoutput::
  1034. [[5 6]
  1035. [1 2]
  1036. [3 4]]
  1037. """
  1038. shp_bak = None
  1039. if axis is None:
  1040. shp_bak = inp.shape
  1041. inp = inp.flatten()
  1042. axis = 0
  1043. shp = inp.shape
  1044. dim = len(shp)
  1045. if isinstance(shift, int):
  1046. assert isinstance(axis, int)
  1047. shift, axis = [shift,], [axis,]
  1048. assert len(shift) == len(axis)
  1049. out = inp
  1050. for i in range(len(shift)):
  1051. axis_ = axis[i]
  1052. shift_ = shift[i]
  1053. axis_normalized_ = axis_ + dim if axis_ < 0 else axis_
  1054. assert (
  1055. dim > axis_normalized_ >= 0
  1056. ), "axis out of range (expected to be in range of [{}, {}], but got {})".format(
  1057. -dim, dim - 1, axis_
  1058. )
  1059. if shift_ == 0:
  1060. continue
  1061. size = shp[axis_normalized_]
  1062. shift_normalized_ = 0 if size == 0 else shift_ % size
  1063. if shift_normalized_ > 0:
  1064. a, b = split(out, [size - shift_normalized_,], axis=axis_normalized_)
  1065. else:
  1066. a, b = split(out, [-shift_normalized_,], axis=axis_normalized_)
  1067. out = concat((b, a), axis=axis_normalized_)
  1068. if shp_bak is not None:
  1069. out = out.reshape(shp_bak)
  1070. return out
  1071. def cumsum(inp: Tensor, axis: int):
  1072. r"""Computes the cumulative sum of elements along given axis.
  1073. Args:
  1074. inp: input tensor.
  1075. axis: axis along which cumsum is performed.
  1076. Examples:
  1077. .. testcode::
  1078. from megengine import tensor
  1079. import megengine.functional as F
  1080. x = tensor([[1, 2, 3], [4, 5, 6]], "int32")
  1081. y = F.cumsum(x, 1)
  1082. print(y.numpy())
  1083. Outputs:
  1084. .. testoutput::
  1085. [[ 1 3 6]
  1086. [ 4 9 15]]
  1087. """
  1088. assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
  1089. assert axis >= 0 and axis < inp.ndim, "input axis {} out of bound".format(axis)
  1090. op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
  1091. return apply(op, inp)[0]