You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tensor.py 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Iterable, Optional, Sequence, Union
  10. import numpy as np
  11. from ..core._imperative_rt import CompNode
  12. from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion
  13. from ..core._wrap import as_device
  14. from ..core.ops import builtin
  15. from ..core.ops.builtin import Copy, Identity
  16. from ..core.ops.special import Const
  17. from ..core.tensor.array_method import _broadcast, _remove_axis
  18. from ..core.tensor.utils import astensor1d, convert_inputs, get_device
  19. from ..device import get_default_device
  20. from ..tensor import Tensor
  21. from .elemwise import ceil, floor_div
  # Public API of this module: the names exported by ``from <module> import *``.
  # NOTE(review): "copy" and "tile" are not defined in the visible part of this
  # file (presumably further down); keep this list in sync with the definitions.
  __all__ = [
      "arange",
      "broadcast_to",
      "concat",
      "cond_take",
      "cumsum",
      "expand_dims",
      "eye",
      "flatten",
      "full",
      "full_like",
      "gather",
      "linspace",
      "ones",
      "ones_like",
      "repeat",
      "reshape",
      "roll",
      "split",
      "squeeze",
      "stack",
      "scatter",
      "tile",
      "copy",
      "transpose",
      "where",
      "zeros",
      "zeros_like",
  ]
  def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
      """
      Returns a 2D tensor with ones on the diagonal and zeros elsewhere.

      :param N: number of rows. When ``M`` is ``None``, ``N`` alone is used as the
          shape specification (presumably a ``(rows, cols)`` pair or a shape tensor).
      :param M: number of columns. Default: None
      :param dtype: data type of the result. Default: ``float32``
      :param device: compute node of the matrix. Default: None
      :return: eye matrix.

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F

          out = F.eye(4, 6, dtype=np.float32)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1. 0. 0. 0. 0. 0.]
           [0. 1. 0. 0. 0. 0.]
           [0. 0. 1. 0. 0. 0.]
           [0. 0. 0. 1. 0. 0.]]

      """
      if M is None:
          # N on its own carries the whole shape specification.
          if isinstance(N, Tensor):
              shape = N
          else:
              shape = Tensor(N, dtype="int32", device=device)
      elif isinstance(N, Tensor) or isinstance(M, Tensor):
          # At least one dimension is a tensor; let astensor1d assemble the shape.
          shape = astensor1d((N, M))
      else:
          shape = Tensor([N, M], dtype="int32", device=device)

      (result,) = apply(builtin.Eye(k=0, dtype=dtype, comp_node=device), shape)
      return result
  def full(shape, value, dtype="float32", device=None) -> Tensor:
      r"""
      Creates a tensor of shape ``shape`` filled with ``value``.

      :param shape: a list, tuple or integer defining the shape of the output tensor.
          An empty tuple ``()`` produces a scalar tensor.
      :param value: the value to fill the output tensor with.
      :param dtype: the desired data type of the output tensor. Default: ``float32``.
      :param device: the desired device of the output tensor. Default: if ``None``,
          use the default device (see ``megengine.get_default_device()``).
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F

          out = F.full([2,3], 1.5)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1.5 1.5 1.5]
           [1.5 1.5 1.5]]

      """
      if isinstance(shape, int):
          shape = (shape,)
      if device is None:
          device = get_default_device()
      (x,) = Const(value, dtype=dtype, device=device)()
      # Scalar request. Use an explicit type/length test rather than
      # ``shape is ()``: identity comparison with a literal depends on CPython
      # caching the empty tuple and emits a SyntaxWarning on Python 3.8+.
      # ``==`` is avoided because ``shape`` may be a tensor.
      if isinstance(shape, tuple) and len(shape) == 0:
          return x
      return broadcast_to(x, shape)
  def ones(shape, dtype="float32", device=None) -> Tensor:
      """
      Returns a ones tensor with given shape.

      :param shape: a list, tuple or integer defining the shape of the output tensor.
      :param dtype: the desired data type of the output tensor. Default: ``float32``.
      :param device: the desired device of the output tensor. Default: if ``None``,
          use the default device.
      :return: output ones tensor.

      Examples:

      .. testcode::

          import megengine.functional as F

          out = F.ones((2, 1))
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1.]
           [1.]]

      """
      fill_value = 1.0
      return full(shape, fill_value, dtype=dtype, device=device)
  def zeros(shape, dtype="float32", device=None) -> Tensor:
      """
      Returns a zero tensor with given shape.

      :param shape: a list, tuple or integer defining the shape of the output tensor.
      :param dtype: the desired data type of the output tensor. Default: ``float32``.
      :param device: the desired device of the output tensor. Default: if ``None``,
          use the default device.
      :return: output zero tensor.

      Examples:

      .. testcode::

          import megengine.functional as F

          out = F.zeros((2, 1))
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[0.]
           [0.]]

      """
      fill_value = 0.0
      return full(shape, fill_value, dtype=dtype, device=device)
  def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
      """
      Returns a zero tensor with the same shape as input tensor.

      The dtype and device are taken from ``inp``.

      :param inp: input tensor.
      :return: output zero tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
          out = F.zeros_like(inp)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[0 0 0]
           [0 0 0]]

      """
      fill_value = 0.0
      return full_like(inp, fill_value)
  def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
      """
      Returns a ones tensor with the same shape as input tensor.

      The dtype and device are taken from ``inp``.

      :param inp: input tensor.
      :return: output ones tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
          out = F.ones_like(inp)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1 1 1]
           [1 1 1]]

      """
      fill_value = 1.0
      return full_like(inp, fill_value)
  def full_like(
      inp: Union[Tensor, SymbolVar], value: Union[int, float]
  ) -> Union[Tensor, SymbolVar]:
      """
      Returns a tensor filled with given value with the same shape as input tensor.

      The dtype and device of the result are taken from ``inp``.

      :param inp: input tensor.
      :param value: target value.
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
          out = F.full_like(inp, 2)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[2 2 2]
           [2 2 2]]

      """
      (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
      target_shape = inp.shape
      # Scalar input: return the constant unbroadcast. Use an explicit
      # type/length test instead of ``is ()`` (identity with a literal warns on
      # Python 3.8+). Avoid ``==`` because ``inp.shape`` may not be a plain
      # tuple (e.g. for symbolic inputs).
      if isinstance(target_shape, tuple) and len(target_shape) == 0:
          return x
      return broadcast_to(x, target_shape)
  def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
      """
      Broadcasts a tensor to given shape.

      :param inp: input tensor.
      :param shape: target shape.
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          data = tensor(np.arange(0, 3, dtype=np.float32).reshape(3))
          out = F.broadcast_to(data, (2, 3))
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[0. 1. 2.]
           [0. 1. 2.]]

      """
      broadcasted = _broadcast(inp, shape)
      return broadcasted
  def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
      r"""
      Concatenates tensors along an existing axis.

      :param inps: input tensors to concat; any iterable (list, tuple, generator).
      :param axis: over which dimension the tensors are concatenated. Default: 0
      :param device: which device output will be. Default: None, i.e. inferred
          from the inputs. Ignored when ``inps`` holds a single tensor, which is
          returned unchanged.
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
          data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
          out = F.concat([data1, data2])
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[ 0.  1.  2.]
           [ 3.  4.  5.]
           [ 6.  7.  8.]
           [ 9. 10. 11.]]

      """
      # Materialize once: ``len`` and indexing below need a sequence, while the
      # annotation promises any iterable (e.g. a generator) is accepted.
      inps = list(inps)
      if len(inps) == 1:
          # A single input is returned as-is (``device`` is not applied).
          return inps[0]

      # FIXME: remove this convert_inputs
      inps = convert_inputs(*inps, device=device)
      if device is None:
          device = get_device(inps)
      device = as_device(device)
      (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
      return result
  def stack(inps, axis=0, device=None):
      """
      Concats a sequence of tensors along a new axis.

      The input tensors must have the same shape.

      :param inps: input tensors.
      :param axis: which axis will be concatenated.
      :param device: the device output will be. Default: None
      :return: output concatenated tensor.
      :raises ValueError: if the (concrete) input shapes differ.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x1 = tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
          x2 = tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
          out = F.stack([x1, x2], axis=0)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[0. 1. 2.]
           [6. 7. 8.]]

      """
      if len(inps) > 0:
          head = inps[0]
          # Shapes are only compared when they are concrete; a shape that is an
          # instance of the tensor's own class is left unchecked.
          if not isinstance(head.shape, head.__class__):
              if len({t.shape for t in inps}) != 1:
                  raise ValueError("All input tensors must have the same shape")

      # Give every input a new length-1 axis, then join along that axis.
      expanded = [expand_dims(t, axis=axis) for t in inps]
      return concat(expanded, axis=axis, device=device)
  def split(inp, nsplits_or_sections, axis=0):
      """
      Splits the input tensor into several smaller tensors.

      When ``nsplits_or_sections`` is an int, the tensor is split into that many
      nearly equal parts; the leading parts absorb the remainder, so the last
      tensor may be smaller than others.

      :param inp: input tensor.
      :param nsplits_or_sections: number of sub tensors, or a sequence of strictly
          increasing split points along ``axis``.
      :param axis: which axis will be split. Negative values count from the last
          axis. Default: 0
      :return: output tensor list.
      :raises ValueError: if ``axis`` is out of range, the split points are not
          strictly increasing inside ``(0, inp.shape[axis])``, or the number of
          sections is not in ``[1, inp.shape[axis]]``.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.random.random((10, 20)), dtype=np.float32)
          y = F.split(x, 3)
          z = F.split(x, [6, 17], axis=1)
          print([i.numpy().shape for i in y])
          print([i.numpy().shape for i in z])

      Outputs:

      .. testoutput::

          [(4, 20), (3, 20), (3, 20)]
          [(10, 6), (10, 11), (10, 3)]

      """
      ndim = len(inp.shape)
      if axis < -ndim or axis >= ndim:
          raise ValueError("Invalid axis {}".format(axis))
      if axis < 0:
          # Normalize so the slice tuple below has exactly ``ndim`` entries.
          axis += ndim
      Ntotal = inp.shape[axis]

      # A sized argument is a list of split points; otherwise a section count.
      try:
          Nsections = len(nsplits_or_sections) + 1
          is_array = True
      except TypeError:
          Nsections = int(nsplits_or_sections)
          is_array = False

      if is_array:
          div_points = [0] + list(nsplits_or_sections) + [Ntotal]
          for i in range(1, len(div_points)):
              if div_points[i - 1] >= div_points[i]:
                  raise ValueError(
                      "Invalid nsplits_or_sections: {}".format(nsplits_or_sections)
                  )
      else:  # scalar
          if Nsections <= 0:
              raise ValueError("Number sections must be larger than 0")
          if Nsections > Ntotal:
              raise ValueError(
                  "The size {} at dim {} cannot be split into {} sections".format(
                      Ntotal, axis, Nsections
                  )
              )
          # Nsections is always a Python int here; Ntotal is the operand that
          # can be a tensor (symbolic shape), so it decides which division to use.
          func = (
              floor_div
              if isinstance(Ntotal, (SymbolVar, Tensor))
              else lambda x, y: x // y
          )
          # Section i has size ceil((Ntotal - i) / Nsections); turn the sizes
          # into cumulative boundaries.
          div_points = [0] + [
              func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
          ]
          for i in range(2, Nsections + 1):
              div_points[i] = div_points[i - 1] + div_points[i]

      sub_tensors = []
      for i in range(Nsections):
          l = div_points[i]
          r = div_points[i + 1]
          slices = tuple(
              [slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1)
          )
          sub_tensors.append(inp[slices])
      return sub_tensors
  def _get_idx(index, axis):
      """
      Builds the advanced-indexing tuple used by ``gather`` and ``scatter``.

      Returns one flattened int32 coordinate tensor per dimension of ``index``:
      along ``axis`` it is ``index`` itself (flattened); along every other
      dimension it is the position 0..n-1 broadcast over ``index``'s shape.
      """
      full_shape = index.shape
      rank = len(full_shape)

      def _coords_along(dim):
          # Positions 0..extent-1 along ``dim``, broadcast to ``full_shape``.
          extent = full_shape[dim]
          view_shape = [1] * rank
          view_shape[dim] = extent
          ramp = linspace(0, extent - 1, extent, device=index.device)
          return (
              broadcast_to(ramp.reshape(*view_shape), full_shape)
              .reshape(-1)
              .astype(np.int32)
          )

      return tuple(
          index.reshape(-1) if dim == axis else _coords_along(dim)
          for dim in range(rank)
      )
  def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
      # TODO: rewrite doc
      r"""
      Gathers data from input tensor on axis using index.

      For a 3-D tensor, the output is specified by::

          out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
          out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
          out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2

      if input tensor is a n-dimensional tensor with size
      :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
      then index must be a n-dimensional tensor with size
      :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
      output will have the same size as index.

      :param inp: input tensor.
      :param axis: along which axis to index. Negative values count from the last axis.
      :param index: indices of elements to gather.
      :return: output tensor.
      :raises ValueError: if ``inp`` and ``index`` differ in rank, ``axis`` is out
          of range, or the shapes differ on any dimension other than ``axis``.

      Examples:

      .. testcode::

          import megengine.functional as F
          from megengine import tensor

          inp = tensor([
              [1,2], [3,4], [5,6],
          ])
          index = tensor([[0,2], [1,0]])
          oup = F.gather(inp, 0, index)
          print(oup.numpy())

      Outputs:

      .. testoutput::

          [[1 6]
           [3 2]]

      """
      input_shape = inp.shape
      index_shape = index.shape
      input_dims = len(input_shape)
      index_dims = len(index_shape)
      if input_dims != index_dims:
          raise ValueError(
              "The index tensor must have same dimensions as input tensor, "
              "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
          )

      if not -input_dims <= axis < input_dims:
          raise ValueError(
              "Index axis {} is out of bounds, should be in range [{}, {})".format(
                  axis, -input_dims, input_dims
              )
          )
      if axis < 0:
          axis += input_dims

      for i in range(input_dims):
          if i != axis and input_shape[i] != index_shape[i]:
              raise ValueError(
                  "The input {} and index {} must have the same size apart from axis {}".format(
                      input_shape, index_shape, axis
                  )
              )

      idx = _get_idx(index, axis)
      return inp[idx].reshape(index.shape)  # pylint: disable=no-member
  def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
      # TODO: rewrite doc
      r"""
      Writes all values from the tensor source into input tensor
      at the indices specified in the index tensor.

      For each value in source, its output index is specified by its index
      in source for ``axis != dimension`` and by the corresponding value in
      index for ``axis = dimension``.

      For a 3-D tensor, input tensor is updated as::

          inp[index[i][j][k]][j][k] = source[i][j][k]  # if axis == 0
          inp[i][index[i][j][k]][k] = source[i][j][k]  # if axis == 1
          inp[i][j][index[i][j][k]] = source[i][j][k]  # if axis == 2

      ``inp``, ``index`` and ``source`` should have same number of dimensions.

      It is also required that ``source.shape(d) <= inp.shape(d)`` and
      ``index.shape(d) == source.shape(d)`` for all dimensions ``d``.

      Moreover, the values of index must be between ``0`` and
      ``inp.shape(axis) - 1`` inclusive.

      ``inp`` is updated in place and also returned.

      .. note::

          Due to performance considerations, the result is nondeterministic on
          GPU devices when different positions of ``source`` are scattered to
          the same destination position according to ``index``.

          For example, if ``index[1][2]`` in the code below were changed from 1
          to 0, ``oup[0][2]`` could come either from ``source[0][2]`` (0.2256)
          or from ``source[1][2]`` (0.5939).

      :param inp: inp tensor which to be scattered.
      :param axis: axis along which to index. Negative values count from the last axis.
      :param index: indices of elements to scatter.
      :param source: source element(s) to scatter.
      :return: ``inp`` after the update.
      :raises ValueError: if the ranks differ, ``axis`` is out of range, ``source``
          is larger than ``inp`` on some dimension, or ``index`` and ``source``
          shapes differ.

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F
          from megengine import tensor

          inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
          source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
          index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
          oup = F.scatter(inp, 0, index,source)
          print(oup.numpy())

      Outputs:

      .. testoutput::

          [[0.9935 0.0718 0.2256 0.     0.    ]
           [0.     0.     0.5939 0.357  0.4396]
           [0.7723 0.9465 0.     0.8926 0.4576]]

      """
      input_shape = inp.shape
      index_shape = index.shape
      source_shape = source.shape
      input_dims = len(input_shape)
      index_dims = len(index_shape)
      source_dims = len(source_shape)

      if input_dims != index_dims or input_dims != source_dims:
          raise ValueError("The input, source and index tensor must have same dimensions")

      if not -input_dims <= axis < input_dims:
          raise ValueError(
              "Index axis {} is out of bounds, should be in range [{}, {})".format(
                  axis, -input_dims, input_dims
              )
          )
      if axis < 0:
          axis += input_dims

      for i in range(source_dims):
          if source_shape[i] > input_shape[i]:
              raise ValueError(
                  "The each shape size for source {} must be less than or equal to input {} ".format(
                      source_shape, input_shape
                  )
              )

      for i in range(index_dims):
          if index_shape[i] != source_shape[i]:
              raise ValueError(
                  "The each shape size for index {} must be equal to source {} ".format(
                      index_shape, source_shape
                  )
              )

      # After the two checks above, index.shape == source.shape <= inp.shape on
      # every dimension, so the index always fits inside ``inp`` apart from the
      # values along ``axis`` (a separate index-vs-input check could never fire).

      idx = _get_idx(index, axis)
      inp[idx] = source.flatten()
      return inp
  def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
      r"""
      Selects elements either from Tensor x or Tensor y, according to mask.

      .. math::

          \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i

      ``x`` and ``y`` are first cast to their promoted common dtype.

      :param mask: a mask used for choosing ``x`` or ``y``; must be a bool tensor.
      :param x: first choice.
      :param y: second choice.
      :return: output tensor with the shape of ``x``.
      :raises TypeError: if any argument is not a ``Tensor``.
      :raises ValueError: if ``mask`` is not bool or ``x`` and ``mask`` are on
          different devices.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
          x = tensor(np.array([[1, np.inf], [np.nan, 4]],
              dtype=np.float32))
          y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
          out = F.where(mask, x, y)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1. 6.]
           [7. 4.]]

      """
      operands = (
          (x, "input x must be a tensor"),
          (y, "input y must be a tensor"),
          (mask, "mask must be a tensor"),
      )
      for operand, type_error in operands:
          if not isinstance(operand, Tensor):
              raise TypeError(type_error)
      if mask.dtype != np.bool_:
          raise ValueError("mask must be bool")
      if x.device != mask.device:
          raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))

      # Bring both branches to a common dtype before merging them.
      common_dtype = dtype_promotion(x, y)
      if x.dtype != common_dtype:
          x = x.astype(common_dtype)
      if y.dtype != common_dtype:
          y = y.astype(common_dtype)

      # Selected values of each branch together with their flat positions.
      from_x, x_positions = cond_take(mask, x)
      from_y, y_positions = cond_take(~mask, y)
      # The concatenation has one slot per mask element (assuming mask, x and y
      # share one shape); overwrite it so every value lands at its flat position.
      merged = concat([from_x, from_y])
      merged[x_positions] = from_x
      merged[y_positions] = from_y
      return merged.reshape(x.shape)
  def cond_take(mask: Tensor, x: Tensor) -> Tensor:
      r"""
      Takes elements from data if specific condition is satisfied on mask.

      This operator has two outputs: the first is the elements taken,
      and the second is the indices corresponding to those elements;
      they are both 1-dimensional. High-dimension input would first be flattened.

      :param mask: condition param; must be the same shape with data and of bool dtype.
      :param x: input tensor from which to take elements.
      :return: a pair ``(values, indices)``.
      :raises TypeError: if ``x`` or ``mask`` is not a ``Tensor``/``SymbolVar``.
      :raises ValueError: if ``mask`` is not bool or the devices differ.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
          x = tensor(np.array([[1, np.inf], [np.nan, 4]],
              dtype=np.float32))
          v, index = F.cond_take(mask, x)
          print(v.numpy(), index.numpy())

      Outputs:

      .. testoutput::

          [1. 4.] [0 3]

      """
      tensor_types = (Tensor, SymbolVar)
      if not isinstance(x, tensor_types):
          raise TypeError("input must be a tensor")
      if not isinstance(mask, tensor_types):
          raise TypeError("mask must be a tensor")
      if mask.dtype != np.bool_:
          raise ValueError("mask must be bool")
      if x.device != mask.device:
          raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))

      values, indices = apply(builtin.CondTake(), x, mask)
      return values, indices
  def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
      r"""
      Swaps shapes and strides according to given pattern.

      :param inp: input tensor.
      :param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
          and any number of ``'x'`` char in dimensions where this tensor should be
          broadcasted. For examples:

          * (``'x'``) -> make a 0d (scalar) into a 1d vector
          * (0, 1) -> identity for 2d vectors
          * (1, 0) -> inverts the first and second dimensions
          * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
          * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
          * (2, 0, 1) -> AxBxC to CxAxB
          * (0, ``'x'``, 1) -> AxB to Ax1xB
          * (1, ``'x'``, 0) -> AxB to Bx1xA
          * (1,) -> this removes dimensions 0. It must be a broadcastable dimension (1xA to A)

      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
          out = F.transpose(x, (1, 0))
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1 0]
           [1 0]]

      """
      # The ``'x'`` placeholder is forwarded to ``Tensor.transpose`` as ``-1``.
      axes = [-1 if dim == "x" else dim for dim in pattern]
      return inp.transpose(axes)
  def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
      r"""
      Reshapes a tensor to given target shape; total number of logical elements
      must remain unchanged.

      :param inp: input tensor.
      :param target_shape: target shape, it can contain an element of -1
          representing ``unspec_axis``.
      :return: reshaped tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(12, dtype=np.int32))
          out = F.reshape(x, (3, 4))
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[ 0  1  2  3]
           [ 4  5  6  7]
           [ 8  9 10 11]]

      """
      reshaped = inp.reshape(target_shape)
      return reshaped
  def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
      r"""
      Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis``
      to dimension ``end_axis`` (both inclusive).

      :param inp: input tensor.
      :param start_axis: start dimension that the sub-tensor to be flattened.
          Negative values count from the last axis. Default: 0
      :param end_axis: end dimension that the sub-tensor to be flattened.
          Negative values count from the last axis. Default: -1
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          inp_shape = (2, 2, 3, 3)
          x = tensor(
              np.arange(36, dtype=np.int32).reshape(inp_shape),
          )
          out = F.flatten(x, 2)
          print(x.numpy().shape)
          print(out.numpy().shape)

      Outputs:

      .. testoutput::

          (2, 2, 3, 3)
          (2, 2, 9)

      """
      if start_axis < 0:
          # Without this, range(start_axis) below is empty and every dimension
          # would be flattened regardless of the requested start.
          start_axis += len(inp.shape)
      # Keep dims before start_axis, collapse start..end into -1, then keep the
      # dims after end_axis (none when end_axis is -1).
      target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
      if end_axis != -1:
          target_shape += (*inp.shape[end_axis + 1 :],)
      return inp.reshape(*target_shape)
  def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
      r"""
      Adds dimension before given axis.

      :param inp: input tensor.
      :param axis: place of new axes; a single int or a sequence of ints.
          Negative values are resolved against the output rank, which requires
          ``inp.ndim`` to be known.
      :return: output tensor.
      :raises IndexError: if a negative axis is given while ``inp.ndim`` is unknown.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor([1, 2])
          out = F.expand_dims(x, 0)
          print(out.numpy().shape)

      Outputs:

      .. testoutput::

          (1, 2)

      """
      # Anything int() accepts is a single axis; otherwise treat it as a sequence.
      try:
          axes = [int(axis)]
      except (TypeError, ValueError):
          axes = None
      if axes is None:
          axes = list(map(int, axis))

      try:
          out_ndim = inp.ndim + len(axes)
          axes = sorted(a + out_ndim if a < 0 else a for a in axes)
      except ValueError:
          # ``inp.ndim`` is unavailable: negative axes cannot be resolved.
          if any(a < 0 for a in axes):
              raise IndexError(
                  "Does not support negative index when tensor's ndim is unknown"
              )
          axes = sorted(axes)

      assert axes, "axis could not be empty"

      if inp._isscalar():
          assert axes[0] == 0, "invalid axis {} for ndim 0".format(axes[0])
          if len(axes) == 1:
              # A single new axis on a scalar: copy and clear the scalar flag
              # instead of applying AddAxis.
              promoted = copy(inp, device=None)
              promoted._unsetscalar()
              return promoted
          axes = axes[1:]

      (result,) = apply(builtin.AddAxis(axis=axes), inp)
      return result
  def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
      r"""
      Removes dimension of shape 1.

      :param inp: input tensor.
      :param axis: place of axis to be removed. Default: None
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
          out = F.squeeze(x, 3)
          print(out.numpy().shape)

      Outputs:

      .. testoutput::

          (1, 1, 2)

      """
      squeezed = _remove_axis(inp, axis)
      return squeezed
  def linspace(
      start: Union[int, float, Tensor],
      stop: Union[int, float, Tensor],
      num: Union[int, Tensor],
      dtype="float32",
      device: Optional[CompNode] = None,
  ) -> Tensor:
      r"""
      Returns equally spaced numbers over a specified interval.

      :param start: starting value of the sequence, should be scalar.
      :param stop: last value of the sequence (included), should be scalar.
      :param num: number of values to generate.
      :param dtype: result data type. Default: ``float32``
      :param device: device of the result. Default: None, i.e. taken from the
          first of ``start``, ``stop``, ``num`` that carries a device.
      :return: generated tensor.
      :raises ValueError: if the inputs (or ``device``) disagree on the device.
      :raises TypeError: if only some of ``start``, ``stop`` and ``num`` are
          ``SymbolVar``.

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F

          a = F.linspace(3, 10, 5)
          print(a.numpy())

      Outputs:

      .. testoutput::

          [ 3.    4.75  6.5   8.25 10.  ]

      """
      # Resolve the output device: every input that carries a device must agree
      # with ``device`` (explicit, or inherited from an earlier input).
      for item in (start, stop, num):
          cur_device = getattr(item, "device", None)
          if device is None:
              device = cur_device
          elif not (cur_device is None or device == cur_device):
              # Previously ``raise ("...")`` raised a plain str, which Python 3
              # turns into an unrelated TypeError; raise a real exception.
              raise ValueError("ambiguous device for linspace opr")

      is_symbolvar = [isinstance(x, SymbolVar) for x in (start, stop, num)]
      if any(is_symbolvar) and not all(is_symbolvar):
          raise TypeError("start, stop and num should all be VarNode or none of them")

      if not isinstance(start, (Tensor, SymbolVar)):
          start = Tensor(start, device=device)
      if not isinstance(stop, (Tensor, SymbolVar)):
          stop = Tensor(stop, device=device)
      if not isinstance(num, (Tensor, SymbolVar)):
          num = Tensor(num, device=device)

      op = builtin.Linspace(comp_node=device)
      (result,) = apply(op, start, stop, num)
      # Assumes the Linspace op yields float32 (the default ``dtype``); honour
      # any other requested dtype instead of only int32.
      if np.dtype(dtype) != np.float32:
          return result.astype(dtype)
      return result
  768. def arange(
  769. start: Union[int, float, Tensor] = 0,
  770. stop: Optional[Union[int, float, Tensor]] = None,
  771. step: Union[int, float, Tensor] = 1,
  772. dtype="float32",
  773. device: Optional[CompNode] = None,
  774. ) -> Tensor:
  775. r"""
  776. Returns a tensor with values from start to stop with adjacent interval step.
  777. :param start: starting value of the squence, shoule be scalar.
  778. :param stop: ending value of the squence, shoule be scalar.
  779. :param step: gap between each pair of adjacent values. Default: 1
  780. :param dtype: result data type.
  781. :return: generated tensor.
  782. Examples:
  783. .. testcode::
  784. import numpy as np
  785. import megengine.functional as F
  786. a = F.arange(5)
  787. print(a.numpy())
  788. Outputs:
  789. Outputs:
  790. .. testoutput::
  791. [0. 1. 2. 3. 4.]
  792. """
  793. if stop is None:
  794. start, stop = 0, start
  795. start = Tensor(start, dtype="float32")
  796. stop = Tensor(stop, dtype="float32")
  797. step = Tensor(step, dtype="float32")
  798. num = ceil((stop - start) / step)
  799. stop = start + step * (num - 1)
  800. result = linspace(start, stop, num, device=device)
  801. if np.dtype(dtype) == np.int32:
  802. return result.astype(dtype)
  803. return result
  804. def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
  805. """
  806. Repeat elements of an array.
  807. :param inp: input tensor.
  808. :param repeats: the number of repetitions for each element.
  809. :param axis: the axis along which to repeat values. By default, use the
  810. flattened input array, and return a flat output array.
  811. :return: output tensor.
  812. Examples:
  813. .. testcode::
  814. import numpy as np
  815. import megengine.functional as F
  816. from megengine import tensor
  817. x = tensor([[1, 2], [3, 4]], np.int32)
  818. y = F.repeat(x, 2, axis=0)
  819. print(y.numpy())
  820. Outputs:
  821. .. testoutput::
  822. [[1 2]
  823. [1 2]
  824. [3 4]
  825. [3 4]]
  826. """
  827. if axis is None:
  828. inp = inp.reshape(-1) # flatten
  829. axis = 0
  830. if inp._isscalar():
  831. inp._unsetscalar()
  832. shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
  833. # assume inp.ndim is not changed during trace
  834. max_axis = len(shape) - 1
  835. assert axis >= 0 and axis <= max_axis
  836. assert repeats >= 1
  837. base_shape, bcast_shape, target_shape = [], [], []
  838. if axis != 0:
  839. target_shape.append(shape[:axis])
  840. base_shape.extend([shape[: axis + 1], [1,]])
  841. bcast_shape.extend([shape[: axis + 1], [repeats,]])
  842. target_shape.extend(
  843. [shape[axis] * repeats,]
  844. )
  845. if axis + 1 <= max_axis:
  846. base_shape.append(shape[axis + 1 :])
  847. bcast_shape.append(shape[axis + 1 :])
  848. target_shape.append(shape[axis + 1 :])
  849. out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
  850. concat(target_shape)
  851. )
  852. return out
  853. def _tile_one_dim(inp, rep, axis):
  854. shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
  855. # assume inp.ndim is not changed during trace
  856. max_axis = len(shape) - 1
  857. base_shape, bcast_shape, target_shape = [], [], []
  858. if axis != 0:
  859. base_shape.append(shape[:axis])
  860. bcast_shape.append(shape[:axis])
  861. target_shape.append(shape[:axis])
  862. base_shape.extend([[1,], shape[axis:]])
  863. bcast_shape.extend([rep, shape[axis:]])
  864. target_shape.append(shape[axis] * rep)
  865. if axis + 1 <= max_axis:
  866. target_shape.append(shape[axis + 1 :])
  867. out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
  868. concat(target_shape)
  869. )
  870. return out
  871. def tile(inp: Tensor, reps: Iterable[int]):
  872. """
  873. Construct an array by repeating ``inp`` the number of times given by ``reps``. If reps has length d,
  874. the result will have dimension of ``max(d, inp.ndim)``. It is required that ``d >= inp.dim``. If ``inp.ndim < d``,
  875. ``inp`` is promoted to be ``d``-dimensional by prepending new axis.
  876. :param inp: input tensor.
  877. :param reps: The number of repetitions of inp along each axis.
  878. :return: output tensor.
  879. Examples:
  880. .. testcode::
  881. import numpy as np
  882. import megengine.functional as F
  883. from megengine import tensor
  884. x = tensor([[1, 2], [3, 4]], np.int32)
  885. y = F.tile(x, (2,1))
  886. print(y.numpy())
  887. Outputs:
  888. .. testoutput::
  889. [[1 2]
  890. [3 4]
  891. [1 2]
  892. [3 4]]
  893. """
  894. shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
  895. reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
  896. l_shape = len(shape)
  897. l_reps = len(reps)
  898. assert (
  899. l_reps >= l_shape
  900. ), "Number of dimensions of tiled dims can not be smaller than number of dimensions of tensor"
  901. for i in range(l_shape):
  902. rep = reps[i + (l_reps - l_shape)]
  903. inp = _tile_one_dim(inp, rep, i)
  904. if l_reps > l_shape:
  905. shape = inp.shape
  906. extra = reps[:-l_shape]
  907. extra_ones = ones_like(extra)
  908. base_shape = concat([extra_ones, shape])
  909. bcast_shape = concat([extra, shape])
  910. target_shape = concat([extra, shape])
  911. inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
  912. return inp
  913. def copy(inp, device=None):
  914. r"""
  915. Copies tensor to another device.
  916. :param inp: input tensor.
  917. :param device: destination device.
  918. Examples:
  919. .. testcode::
  920. import numpy as np
  921. import platform
  922. from megengine import tensor
  923. from megengine.device import get_device_count
  924. import megengine.functional as F
  925. x = tensor([1, 2, 3], np.int32)
  926. if 1 == get_device_count("gpu"):
  927. y = F.copy(x, "cpu1")
  928. print(y.numpy())
  929. else:
  930. y = F.copy(x, "xpu1")
  931. print(y.numpy())
  932. Outputs:
  933. .. testoutput::
  934. [1 2 3]
  935. """
  936. if device is None:
  937. return apply(Identity(), inp)[0]
  938. return apply(Copy(comp_node=as_device(device).to_c()), inp)[0]
  939. def roll(
  940. inp: Tensor,
  941. shift: Union[int, Iterable[int]],
  942. axis: Optional[Union[int, Iterable[int]]] = None,
  943. ):
  944. """
  945. Roll the tensor along the given axis(or axes). Elements that are shifted
  946. beyond the last position are re-introduced at the first position.
  947. :param inp: input tensor.
  948. :param shift: the number of places by which the elements of the tensor are
  949. shifted. If shift is a tuple, axis must be a tuple of the same size,
  950. and each axis will be rolled by the corresponding shift value.
  951. :param axis: axis along which to roll. If axis is not specified, the tensor
  952. will be flattened before rolling and then restored to the original shape.
  953. Duplicate axes is allowed if it is a tuple. Default: None.
  954. Examples:
  955. .. testcode::
  956. import numpy as np
  957. from megengine import tensor
  958. import megengine.functional as F
  959. x = tensor([[1,2],[3,4],[5,6]], np.int32)
  960. y = F.roll(x, 1, 0)
  961. print(y.numpy())
  962. Outputs:
  963. .. testoutput::
  964. [[5 6]
  965. [1 2]
  966. [3 4]]
  967. """
  968. shp_bak = None
  969. if axis is None:
  970. shp_bak = inp.shape
  971. inp = inp.flatten()
  972. axis = 0
  973. shp = inp.shape
  974. dim = len(shp)
  975. if isinstance(shift, int):
  976. assert isinstance(axis, int)
  977. shift, axis = [shift,], [axis,]
  978. assert len(shift) == len(axis)
  979. out = inp
  980. for i in range(len(shift)):
  981. axis_ = axis[i]
  982. shift_ = shift[i]
  983. axis_normalized_ = axis_ + dim if axis_ < 0 else axis_
  984. assert (
  985. dim > axis_normalized_ >= 0
  986. ), "axis out of range (expected to be in range of [{}, {}], but got {})".format(
  987. -dim, dim - 1, axis_
  988. )
  989. if shift_ == 0:
  990. continue
  991. size = shp[axis_normalized_]
  992. if shift_ > 0:
  993. a, b = split(out, [size - shift_,], axis=axis_normalized_)
  994. else:
  995. a, b = split(out, [-shift_,], axis=axis_normalized_)
  996. out = concat((b, a), axis=axis_normalized_)
  997. if shp_bak is not None:
  998. out = out.reshape(shp_bak)
  999. return out
  1000. def cumsum(inp: Tensor, axis: int):
  1001. """
  1002. Computes the cumulative sum of elements along given axis.
  1003. :param inp: input tensor.
  1004. :param axis: axis along which cumsum is performed.
  1005. Examples:
  1006. .. testcode::
  1007. from megengine import tensor
  1008. import megengine.functional as F
  1009. x = tensor([[1, 2, 3], [4, 5, 6]], "int32")
  1010. y = F.cumsum(x, 1)
  1011. print(y.numpy())
  1012. Outputs:
  1013. .. testoutput::
  1014. [[ 1 3 6]
  1015. [ 4 9 15]]
  1016. """
  1017. assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
  1018. assert axis >= 0 and axis < inp.ndim, "input axis {} out of bound".format(axis)
  1019. op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
  1020. return apply(op, inp)[0]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。