You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensor.py 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import math
  11. from itertools import accumulate
  12. from typing import Iterable, List, Optional, Sequence, Tuple, Union
  13. import numpy as np
  14. from ..core._imperative_rt import CompNode
  15. from ..core._wrap import device as as_device
  16. from ..core.ops import builtin
  17. from ..core.ops._internal import param_defs as P
  18. from ..core.ops.special import Const
  19. from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
  20. from ..core.tensor.utils import (
  21. astensor1d,
  22. convert_inputs,
  23. convert_single_value,
  24. dtype_promotion,
  25. get_device,
  26. )
  27. from ..device import get_default_device
  28. from ..tensor import Tensor
  29. from .elemwise import ceil
  30. __all__ = [
  31. "add_axis",
  32. "arange",
  33. "broadcast",
  34. "concat",
  35. "cond_take",
  36. "transpose",
  37. "add_axis",
  38. "eye",
  39. "flatten",
  40. "full",
  41. "full_like",
  42. "gather",
  43. "identity",
  44. "linspace",
  45. "ones",
  46. "ones_like",
  47. "param_pack_concat",
  48. "param_pack_split",
  49. "reshape",
  50. "remove_axis",
  51. "split",
  52. "stack",
  53. "scatter",
  54. "transpose",
  55. "where",
  56. "zeros",
  57. "zeros_like",
  58. ]
  59. def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
  60. """Returns a 2D tensor with ones on the diagonal and zeros elsewhere.
  61. :param shape: expected shape of otuput tensor.
  62. :param m: number of columns. Default: None
  63. :param dtype: data type. Default: None
  64. :param device: compute node of the matrix. Default: None
  65. :return: eye matrix.
  66. Examples:
  67. .. testcode::
  68. import numpy as np
  69. import megengine.functional as F
  70. data_shape = (4, 6)
  71. out = F.eye(data_shape, dtype=np.float32)
  72. print(out.numpy())
  73. Outputs:
  74. .. testoutput::
  75. [[1. 0. 0. 0. 0. 0.]
  76. [0. 1. 0. 0. 0. 0.]
  77. [0. 0. 1. 0. 0. 0.]
  78. [0. 0. 0. 1. 0. 0.]]
  79. """
  80. op = builtin.Eye(k=0, dtype=dtype, comp_node=device)
  81. (result,) = apply(op, Tensor(shape, dtype="int32", device=device))
  82. return result
  83. def full(shape, value, dtype="float32", device=None):
  84. """Returns a tensor with given shape and value.
  85. """
  86. if isinstance(shape, int):
  87. shape = (shape,)
  88. if device is None:
  89. device = get_default_device()
  90. (x,) = Const(value, dtype=dtype, device=device)(
  91. Tensor(value, dtype=dtype, device=device)
  92. )
  93. return broadcast(x, shape)
  94. def ones(shape, dtype="float32", device=None):
  95. """Returns a ones tensor with given shape.
  96. :param inp: input tensor.
  97. :return: output zero tensor.
  98. Examples:
  99. .. testcode::
  100. import megengine.functional as F
  101. out = F.ones((2, 1))
  102. print(out.numpy())
  103. Outputs:
  104. .. testoutput::
  105. [[1.]
  106. [1.]]
  107. """
  108. return full(shape, 1.0, dtype=dtype, device=device)
  109. def zeros(shape, dtype="float32", device=None):
  110. """Returns a zero tensor with given shape.
  111. """
  112. return full(shape, 0.0, dtype=dtype, device=device)
  113. def zeros_like(inp: Tensor) -> Tensor:
  114. """Returns a zero tensor with the same shape as input tensor.
  115. :param inp: input tensor.
  116. :return: output zero tensor.
  117. Examples:
  118. .. testcode::
  119. import numpy as np
  120. from megengine import tensor
  121. import megengine.functional as F
  122. inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  123. out = F.zeros_like(inp)
  124. print(out.numpy())
  125. Outputs:
  126. .. testoutput::
  127. [[0 0 0]
  128. [0 0 0]]
  129. """
  130. return zeros(inp.shape, dtype=inp.dtype, device=inp.device)
  131. def ones_like(inp: Tensor) -> Tensor:
  132. """Returns a identity tensor with the same shape as input tensor.
  133. """
  134. return ones(inp.shape, dtype=inp.dtype, device=inp.device)
  135. def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
  136. """Returns a tensor filled with given value with the same shape as input tensor.
  137. """
  138. return full(inp.shape, value, dtype=inp.dtype, device=inp.device)
  139. def identity(inp: Tensor) -> Tensor:
  140. """Applies an identity transform to the input tensor.
  141. :param inp: input tensor.
  142. :return: output tensor.
  143. """
  144. op = builtin.Identity()
  145. (data,) = convert_inputs(inp)
  146. (output,) = apply(op, data)
  147. return output
  148. def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
  149. """
  150. Broadcasts a tensor to given shape.
  151. :param inp: input tensor.
  152. :param shape: target shape.
  153. :return: output tensor.
  154. Examples:
  155. .. testcode::
  156. import numpy as np
  157. from megengine import tensor
  158. import megengine.functional as F
  159. data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  160. out = F.broadcast(data, (4, 2, 3))
  161. print(out.numpy())
  162. Outputs:
  163. .. testoutput::
  164. [[[0. 1. 2.]
  165. [3. 4. 5.]]
  166. [[0. 1. 2.]
  167. [3. 4. 5.]]
  168. [[0. 1. 2.]
  169. [3. 4. 5.]]
  170. [[0. 1. 2.]
  171. [3. 4. 5.]]]
  172. """
  173. shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
  174. (result,) = apply(builtin.Broadcast(), inp, shape)
  175. return result
  176. def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
  177. r"""
  178. Concat some tensors
  179. :param inps: input tensors to concat.
  180. :param axis: dimension over which the tensors are concatenated. Default: 0
  181. :param device: comp node output on. Default: None
  182. :return: output tensor.
  183. Examples:
  184. .. testcode::
  185. import numpy as np
  186. from megengine import tensor
  187. import megengine.functional as F
  188. data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
  189. data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
  190. out = F.concat([data1, data2])
  191. print(out.numpy())
  192. Outputs:
  193. .. testoutput::
  194. [[ 0. 1. 2.]
  195. [ 3. 4. 5.]
  196. [ 6. 7. 8.]
  197. [ 9. 10. 11.]]
  198. """
  199. if len(inps) == 1:
  200. return inps[0]
  201. dtype = dtype_promotion(inps)
  202. if device is None:
  203. device = get_device(inps)
  204. device = as_device(device)
  205. def convert(x):
  206. return convert_single_value(x, inps, dtype=dtype)
  207. inps = tuple(map(convert, inps))
  208. (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
  209. return result
  210. def stack(inps, axis=0, device=None):
  211. """Concats a sequence of tensors along a new axis.
  212. The input tensors must have the same shape.
  213. :param inps: input tensors.
  214. :param axis: which axis will be concatenated.
  215. :param device: The comp node output on. Default: None
  216. :return: output concatenated tensor.
  217. Examples:
  218. .. testcode::
  219. import numpy as np
  220. from megengine import tensor
  221. import megengine.functional as F
  222. x1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
  223. x2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
  224. out = F.stack([x1, x2], axis=0)
  225. print(out.numpy())
  226. Outputs:
  227. .. testoutput::
  228. [[[ 0. 1. 2.]
  229. [ 3. 4. 5.]]
  230. [[ 6. 7. 8.]
  231. [ 9. 10. 11.]]]
  232. """
  233. if len(inps) > 0 and not isinstance(inps[0].shape, inps[0].__class__):
  234. shapes = {arr.shape for arr in inps}
  235. if len(shapes) != 1:
  236. raise ValueError("All input tensors must have the same shape")
  237. inps = [add_axis(inp, axis=axis) for inp in inps]
  238. return concat(inps, axis=axis, device=device)
  239. def split(inp, nsplits_or_sections, axis=0):
  240. """Splits the input tensor into several smaller tensors.
  241. When nsplits_or_sections is int, the last tensor may be smaller than others.
  242. :param inp: input tensor.
  243. :param nsplits_or_sections: number of sub tensors or section information list.
  244. :param axis: which axis will be splited.
  245. :return: output tensor list.
  246. Examples:
  247. .. testcode::
  248. import numpy as np
  249. from megengine import tensor
  250. import megengine.functional as F
  251. x = tensor(np.random.random((2,3,4,5)), dtype=np.float32)
  252. out = F.split(x, 2, axis=3)
  253. print(out[0].shape, out[1].shape)
  254. Outputs:
  255. .. testoutput::
  256. (2, 3, 4, 3) (2, 3, 4, 2)
  257. """
  258. sub_tensors = []
  259. sections = []
  260. def swapaxis(inp, src, dst):
  261. if src == dst:
  262. return inp
  263. shape = [i for i in range(inp.ndim)]
  264. shape[src] = dst
  265. shape[dst] = src
  266. return inp.transpose(shape)
  267. inp = swapaxis(inp, 0, axis)
  268. if isinstance(nsplits_or_sections, int):
  269. incr_step = ceil(inp.shape[0] / nsplits_or_sections)
  270. nsplits = nsplits_or_sections
  271. while nsplits > 0:
  272. nsplits -= 1
  273. sections.append(incr_step.astype("int32"))
  274. incr_step += nsplits_or_sections
  275. else:
  276. sections = nsplits_or_sections
  277. st = 0
  278. for se in sections:
  279. sub_tensors.append(swapaxis(inp[st:se], axis, 0))
  280. st = se
  281. if st < inp.shape[0]:
  282. sub_tensors.append(swapaxis(inp[st:], axis, 0))
  283. return sub_tensors
  284. def _get_idx(index, axis):
  285. index_dims = len(index.shape)
  286. idx = []
  287. for i in range(index_dims):
  288. if i != axis:
  289. shape = [1] * index_dims
  290. shape[i] = index.shape[i]
  291. arange = linspace(
  292. 0, index.shape[i] - 1, index.shape[i], device=index.device,
  293. )
  294. arange = (
  295. arange.reshape(*shape)
  296. .broadcast(index.shape)
  297. .reshape(-1)
  298. .astype(np.int32)
  299. )
  300. idx.append(arange)
  301. else:
  302. idx.append(index.reshape(-1))
  303. return tuple(idx)
  304. def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
  305. r"""Gathers data from inp on axis using index.
  306. For a 3-D tensor, the output is specified by::
  307. out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
  308. out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
  309. out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2
  310. if inp is an n-dimensional tensor with size
  311. :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
  312. then index must be an n-dimensional tensor with size
  313. :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
  314. output will have the same size as index.
  315. :param inp: input tensor.
  316. :param axis: axis along which to index.
  317. :param index: indices of elements to gather.
  318. :return: output tensor.
  319. Examples:
  320. .. testcode::
  321. import megengine.functional as F
  322. from megengine import tensor
  323. inp = tensor([
  324. [1,2], [3,4], [5,6],
  325. ])
  326. index = tensor([[0,2], [1,0]])
  327. oup = F.gather(inp, 0, index)
  328. print(oup.numpy())
  329. Outputs:
  330. .. testoutput::
  331. [[1 6]
  332. [3 2]]
  333. """
  334. input_shape = inp.shape
  335. index_shape = index.shape
  336. input_dims = len(input_shape)
  337. index_dims = len(index_shape)
  338. if input_dims != index_dims:
  339. raise ValueError(
  340. "The index tensor must have same dimensions as input tensor, "
  341. "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
  342. )
  343. if axis < 0 or axis >= input_dims:
  344. raise ValueError(
  345. "Index axis {} is output of bounds, should in range [0 {})".format(
  346. axis, input_dims
  347. )
  348. )
  349. for i in range(input_dims):
  350. if i != axis and input_shape[i] != index_shape[i]:
  351. raise ValueError(
  352. "The input {} and index {} must have the same size apart from axis {}".format(
  353. input_shape, index_shape, axis
  354. )
  355. )
  356. idx = _get_idx(index, axis)
  357. return inp[idx].reshape(index.shape) # pylint: disable=no-member
  358. def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
  359. r"""Writes all values from the tensor source into inp
  360. at the indices specified in the index tensor.
  361. For each value in source, its output index is specified by its index
  362. in source for ``axis != dimension`` and by the corresponding value in
  363. index for ``axis = dimension``.
  364. For a 3-D tensor, inp is updated as::
  365. inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0
  366. inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1
  367. inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2
  368. inp, index and source should have same number of dimensions.
  369. It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
  370. for all dimensions ``d``.
  371. Moreover, the values of index must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.
  372. .. note::
  373. Please notice that, due to performance issues, the result is uncertain on the GPU device
  374. if scatter difference positions from source to the same destination position
  375. regard to index tensor.
  376. Show the case using the following examples, the oup[0][2] is maybe
  377. from source[0][2] which value is 0.2256 or source[1][2] which value is 0.5339
  378. if set the index[1][2] from 1 to 0.
  379. :param inp: inp tensor which to be scattered.
  380. :param axis: axis along which to index.
  381. :param index: indices of elements to scatter.
  382. :param source: source element(s) to scatter.
  383. :return: output tensor.
  384. Examples:
  385. .. testcode::
  386. import numpy as np
  387. import megengine.functional as F
  388. from megengine import tensor
  389. inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
  390. source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
  391. index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
  392. oup = F.scatter(inp, 0, index,source)
  393. print(oup.numpy())
  394. Outputs:
  395. .. testoutput::
  396. [[0.9935 0.0718 0.2256 0. 0. ]
  397. [0. 0. 0.5939 0.357 0.4396]
  398. [0.7723 0.9465 0. 0.8926 0.4576]]
  399. """
  400. input_shape = inp.shape
  401. index_shape = index.shape
  402. source_shape = source.shape
  403. input_dims = len(input_shape)
  404. index_dims = len(index_shape)
  405. source_dims = len(source_shape)
  406. if input_dims != index_dims or input_dims != source_dims:
  407. raise ValueError("The input, source and index tensor must have same dimensions")
  408. if axis < 0 or axis >= input_dims:
  409. raise ValueError(
  410. "Index axis {} is output of bounds, should in range [0 {})".format(
  411. axis, input_dims
  412. )
  413. )
  414. for i in range(source_dims):
  415. if source_shape[i] > input_shape[i]:
  416. raise ValueError(
  417. "The each shape size for source {} must be less than or equal to input {} ".format(
  418. source_shape, input_shape
  419. )
  420. )
  421. for i in range(index_dims):
  422. if index_shape[i] != source_shape[i]:
  423. raise ValueError(
  424. "The each shape size for index {} must be equal to source {} ".format(
  425. index_shape, source_shape
  426. )
  427. )
  428. for i in range(index_dims):
  429. if i != axis and index_shape[i] > input_shape[i]:
  430. raise ValueError(
  431. "The index {} must be less than or equal to input {} size apart from axis {}".format(
  432. index_shape, input_shape, axis
  433. )
  434. )
  435. idx = _get_idx(index, axis)
  436. inp[idx] = source.flatten()
  437. return inp
  438. def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
  439. r"""Selects elements either from Tensor x or Tensor y, according to mask.
  440. .. math::
  441. \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i
  442. :param mask: a mask used for choosing x or y.
  443. :param x: first choice.
  444. :param y: second choice.
  445. :return: output tensor.
  446. Examples:
  447. .. testcode::
  448. from megengine import tensor
  449. import megengine.functional as F
  450. mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool))
  451. x = tensor(np.array([[1, np.inf], [np.nan, 4]],
  452. dtype=np.float32))
  453. y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
  454. out = F.where(mask, x, y)
  455. print(out.numpy())
  456. Outputs:
  457. .. testoutput::
  458. [[1. 6.]
  459. [7. 4.]]
  460. """
  461. x, y = convert_inputs(x, y)
  462. if not isinstance(x, (TensorWrapperBase, TensorBase)):
  463. raise TypeError("input x must be a tensor")
  464. if not isinstance(y, (TensorWrapperBase, TensorBase)):
  465. raise TypeError("input y must be a tensor")
  466. if not isinstance(mask, (TensorWrapperBase, TensorBase)):
  467. raise TypeError("mask must be a tensor")
  468. if mask.dtype != np.bool_:
  469. raise ValueError("mask must be bool")
  470. if x.device != mask.device:
  471. raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
  472. v0, index0 = cond_take(mask, x)
  473. v1, index1 = cond_take(~mask, y)
  474. if v0.shape == (0,):
  475. out = v1
  476. elif v1.shape == (0,):
  477. out = v0
  478. else:
  479. out = concat([v0, v1])
  480. out[index0] = v0
  481. out[index1] = v1
  482. out = out.reshape(x.shape)
  483. return out
  484. def cond_take(mask: Tensor, x: Tensor) -> Tensor:
  485. r"""
  486. Take elements from data if specific condition is satisfied on mask.
  487. This operator has two outputs: the first is the elements taken,
  488. and the second is the indices corresponding to those elements;
  489. they are both 1-dimensional. High-dimension input would first be flattened.
  490. :param mask: condition param; must be the same shape with data.
  491. :param x: input tensor from which to take elements.
  492. Examples:
  493. .. testcode::
  494. import numpy as np
  495. from megengine import tensor
  496. import megengine.functional as F
  497. mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
  498. x = tensor(np.array([[1, np.inf], [np.nan, 4]],
  499. dtype=np.float32))
  500. v, index = F.cond_take(mask, x)
  501. print(v.numpy(), index.numpy())
  502. Outputs:
  503. .. testoutput::
  504. [1. 4.] [0 3]
  505. """
  506. if not isinstance(x, (TensorWrapperBase, TensorBase)):
  507. raise TypeError("input must be a tensor")
  508. if not isinstance(mask, (TensorWrapperBase, TensorBase)):
  509. raise TypeError("mask must be a tensor")
  510. if mask.dtype != np.bool_:
  511. raise ValueError("mask must be bool")
  512. if x.device != mask.device:
  513. raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
  514. op = builtin.CondTake()
  515. v, index = apply(op, x, mask)
  516. return v, index
  517. def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
  518. r"""
  519. Swaps shapes and strides according to given pattern.
  520. :param inp: input tensor.
  521. :param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
  522. and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:
  523. * (``'x'``) -> make a 0d (scalar) into a 1d vector
  524. * (0, 1) -> identity for 2d vectors
  525. * (1, 0) -> inverts the first and second dimensions
  526. * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
  527. * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
  528. * (2, 0, 1) -> AxBxC to CxAxB
  529. * (0, ``'x'``, 1) -> AxB to Ax1xB
  530. * (1, ``'x'``, 0) -> AxB to Bx1xA
  531. * (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A)
  532. :return: output tensor.
  533. Examples:
  534. .. testcode::
  535. import numpy as np
  536. from megengine import tensor
  537. import megengine.functional as F
  538. x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
  539. out = F.transpose(x, (1, 0))
  540. print(out.numpy())
  541. Outputs:
  542. .. testoutput::
  543. [[1 0]
  544. [1 0]]
  545. """
  546. op = builtin.Dimshuffle(pattern)
  547. (inp,) = convert_inputs(inp)
  548. (result,) = apply(op, inp)
  549. return result
  550. dimshuffle = transpose
  551. def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
  552. r"""
  553. Reshapes a tensor to given target shape; total number of logical elements must
  554. remain unchanged
  555. :param inp: input tensor.
  556. :param target_shape: target shape, the components would be concatenated to form the
  557. target shape, and it can contain an element of -1 representing unspec_axis.
  558. Examples:
  559. .. testcode::
  560. import numpy as np
  561. from megengine import tensor
  562. import megengine.functional as F
  563. x = tensor(np.arange(12, dtype=np.int32))
  564. out = F.reshape(x, (3, 2, 2))
  565. print(out.numpy())
  566. Outputs:
  567. .. testoutput::
  568. [[[ 0 1]
  569. [ 2 3]]
  570. [[ 4 5]
  571. [ 6 7]]
  572. [[ 8 9]
  573. [10 11]]]
  574. """
  575. if isinstance(target_shape, (TensorBase, TensorWrapperBase)):
  576. target_shape = target_shape.numpy()
  577. target_shape = tuple(map(int, target_shape))
  578. unspec_axis = None
  579. for i, s in enumerate(target_shape):
  580. if s < 0:
  581. if s != -1:
  582. raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
  583. if unspec_axis is not None:
  584. raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
  585. unspec_axis = i
  586. # TODO: device should be None (cpu)
  587. (target_shape,) = Const(target_shape, dtype="int32", device=inp.device)(inp)
  588. if unspec_axis is None:
  589. op = builtin.Reshape()
  590. else:
  591. op = builtin.Reshape(unspec_axis=unspec_axis)
  592. (x,) = apply(op, inp, target_shape)
  593. return x
  594. AxisAddRemove = builtin.AxisAddRemove
  595. AxisDesc = AxisAddRemove.AxisDesc
  596. def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
  597. r"""Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``.
  598. :param inp: input tensor.
  599. :param start_axis: start dimension that the sub-tensor to be flattened. Default: 0
  600. :param end_axis: end dimension that the sub-tensor to be flattened. Default: -1
  601. :return: output tensor.
  602. Examples:
  603. .. testcode::
  604. import numpy as np
  605. from megengine import tensor
  606. import megengine.functional as F
  607. inp_shape = (2, 2, 3, 3)
  608. x = tensor(
  609. np.arange(36, dtype=np.int32).reshape(inp_shape),
  610. )
  611. out = F.flatten(x, 2)
  612. print(x.numpy().shape)
  613. print(out.numpy().shape)
  614. Outputs:
  615. .. testoutput::
  616. (2, 2, 3, 3)
  617. (2, 2, 9)
  618. """
  619. target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
  620. if end_axis != -1:
  621. target_shape += (*inp.shape[end_axis + 1 :],)
  622. return inp.reshape(*target_shape)
  623. def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
  624. r"""
  625. Adds dimension before given axis.
  626. :param inp: input tensor.
  627. :param axis: place of new axes.
  628. :return: output tensor.
  629. Examples:
  630. .. testcode::
  631. import numpy as np
  632. from megengine import tensor
  633. import megengine.functional as F
  634. x = tensor([1, 2])
  635. out = F.add_axis(x, 0)
  636. print(out.shape)
  637. Outputs:
  638. .. testoutput::
  639. (1, 2)
  640. """
  641. Param = builtin.AxisAddRemove.Param
  642. def get_axes():
  643. try:
  644. return [int(axis)]
  645. except (TypeError, ValueError):
  646. pass
  647. return list(map(int, axis))
  648. axis = get_axes()
  649. ndim = inp.ndim + len(axis)
  650. axis = sorted(i + ndim if i < 0 else i for i in axis)
  651. param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_add, axis))
  652. op = builtin.AxisAddRemove(param=param)
  653. (result,) = apply(op, inp)
  654. return result
  655. add_axis = add_axis
  656. def remove_axis(
  657. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None
  658. ) -> Tensor:
  659. r"""
  660. Removes dimension of shape 1.
  661. :param inp: input tensor.
  662. :param axis: place of axis to be removed.
  663. :return: output tensor.
  664. Examples:
  665. .. testcode::
  666. import numpy as np
  667. from megengine import tensor
  668. import megengine.functional as F
  669. x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
  670. out = F.remove_axis(x, 3)
  671. print(out.shape)
  672. Outputs:
  673. .. testoutput::
  674. (1, 1, 2)
  675. """
  676. Param = builtin.AxisAddRemove.Param
  677. def get_axes():
  678. if axis is None:
  679. return [i for i, s in enumerate(inp.shape) if s == 1]
  680. try:
  681. return [int(axis)]
  682. except (TypeError, ValueError):
  683. pass
  684. return list(map(int, axis))
  685. axis = get_axes()
  686. axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
  687. axis = [a - i for i, a in enumerate(axis)]
  688. param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
  689. op = builtin.AxisAddRemove(param=param)
  690. (result,) = apply(op, inp)
  691. return result
  692. def linspace(
  693. start: Union[int, float, Tensor],
  694. stop: Union[int, float, Tensor],
  695. num: Union[int, Tensor],
  696. dtype="float32",
  697. device: Optional[CompNode] = None,
  698. ) -> Tensor:
  699. r"""Returns equally spaced numbers over a specified interval.
  700. :param start: starting value of the squence, shoule be scalar.
  701. :param stop: last value of the squence, shoule be scalar.
  702. :param num: number of values to generate.
  703. :param dtype: result data type.
  704. :return: generated tensor.
  705. Examples:
  706. .. testcode::
  707. import numpy as np
  708. import megengine.functional as F
  709. a = F.linspace(3,10,5)
  710. print(a.numpy())
  711. Outputs:
  712. .. testoutput::
  713. [ 3. 4.75 6.5 8.25 10. ]
  714. """
  715. start = Tensor(start, device=device)
  716. stop = Tensor(stop, device=device)
  717. num = Tensor(num, device=device)
  718. device = device if device is None else device.to_c()
  719. op = builtin.Linspace(comp_node=device)
  720. (result,) = apply(op, start, stop, num)
  721. if np.dtype(dtype) == np.int32:
  722. return result.astype(dtype)
  723. return result
  724. def arange(
  725. start: Union[int, float, Tensor] = 0,
  726. end: Optional[Union[int, float, Tensor]] = None,
  727. step: Union[int, float, Tensor] = 1,
  728. dtype="float32",
  729. device: Optional[CompNode] = None,
  730. ) -> Tensor:
  731. r"""Returns a Tensor with values from start to end with adjacent interval step.
  732. :param start: starting value of the squence, shoule be scalar.
  733. :param end: ending value of the squence, shoule be scalar.
  734. :param step: gap between each pair of adjacent values. Default: 1
  735. :param dtype: result data type.
  736. :return: generated tensor.
  737. Examples:
  738. .. testcode::
  739. import numpy as np
  740. import megengine.functional as F
  741. a = F.arange(5)
  742. print(a.numpy())
  743. Outputs:
  744. Outputs:
  745. .. testoutput::
  746. [0. 1. 2. 3. 4.]
  747. """
  748. if end is None:
  749. start, end = 0, start
  750. if isinstance(start, Tensor):
  751. start = start.astype("float32")
  752. if isinstance(end, Tensor):
  753. end = end.astype("float32")
  754. if isinstance(step, Tensor):
  755. step = step.astype("float32")
  756. num = ceil(Tensor((end - start) / step, device=device))
  757. stop = start + step * (num - 1)
  758. result = linspace(start, stop, num, device=device)
  759. if np.dtype(dtype) == np.int32:
  760. return result.astype(dtype)
  761. return result
  762. def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
  763. r"""
  764. Returns split Tensor to Tensor list as offsets and shapes described,
  765. only used for parampack.
  766. :param inp: input tensor.
  767. :param offsets: offsets of outputs, length of 2 * n,
  768. while n is tensor nums you want to split,
  769. format `[begin0, end0, begin1, end1]`.
  770. :param shapes: tensor shapes of outputs.
  771. :return: split tensors.
  772. Examples:
  773. .. testcode::
  774. import numpy as np
  775. import megengine.functional as F
  776. from megengine import tensor
  777. a = tensor(np.ones((10,), np.int32))
  778. b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
  779. print(b.numpy())
  780. print(c.numpy())
  781. Outputs:
  782. .. testoutput::
  783. [1]
  784. [[1 1 1]
  785. [1 1 1]
  786. [1 1 1]]
  787. """
  788. op = builtin.ParamPackSplit()
  789. op.offsets = offsets
  790. op.shapes = shapes
  791. return apply(op, inp)
  792. def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
  793. r"""
  794. Returns concat Tensor, only used for parampack.
  795. :param inps: input tensors.
  796. :param offsets: device value of offsets.
  797. :param offsets_val: offsets of inputs, length of 2 * n,
  798. format [begin0, end0, begin1, end1].
  799. :return: concat tensors
  800. Examples:
  801. .. testcode::
  802. import numpy as np
  803. import megengine.functional as F
  804. from megengine import tensor
  805. a = tensor(np.ones((1,), np.int32))
  806. b = tensor(np.ones((3, 3), np.int32))
  807. offsets_val = [0, 1, 1, 10]
  808. offsets = tensor(offsets_val, np.int32)
  809. c = F.param_pack_concat([a, b], offsets, offsets_val)
  810. print(c.numpy())
  811. Outputs:
  812. .. testoutput::
  813. [1 1 1 1 1 1 1 1 1 1]
  814. """
  815. op = builtin.ParamPackConcat()
  816. op.offsets = offsets_val
  817. return apply(op, *inps, offsets)[0]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。