You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

math.py 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import math
  11. from functools import lru_cache
  12. from typing import Optional, Sequence, Tuple, Union
  13. from ..core._imperative_rt.core2 import apply, dtype_promotion
  14. from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
  15. from ..core._trace_option import use_symbolic_shape
  16. from ..core.ops import builtin
  17. from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
  18. from ..core.ops.special import Const
  19. from ..core.tensor import amp
  20. from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
  21. from ..jit import exclude_from_trace
  22. from ..tensor import Tensor
  23. from .debug_param import get_execution_strategy
  24. from .elemwise import clip, minimum
  25. from .tensor import broadcast_to, concat, expand_dims, squeeze
# Public names of this module: these are what ``from <module> import *``
# exports. Helpers whose names start with an underscore are left out.
__all__ = [
    "argmax",
    "argmin",
    "argsort",
    "dot",
    "isinf",
    "isnan",
    "matinv",
    "matmul",
    "max",
    "mean",
    "min",
    "norm",
    "normalize",
    "prod",
    "sign",
    "sort",
    "std",
    "sum",
    "svd",
    "topk",
    "var",
]
  49. def isnan(inp: Tensor) -> Tensor:
  50. r"""Returns a new tensor representing if each element is ``NaN`` or not.
  51. Args:
  52. inp: input tensor.
  53. Returns:
  54. result tensor.
  55. Examples:
  56. .. testcode::
  57. from megengine import tensor
  58. import megengine.functional as F
  59. x = tensor([1, float("nan"), 0])
  60. print(F.isnan(x).numpy())
  61. Outputs:
  62. .. testoutput::
  63. [False True False]
  64. """
  65. return inp != inp
  66. def isinf(inp: Tensor) -> Tensor:
  67. r"""Returns a new tensor representing if each element is ``Inf`` or not.
  68. Args:
  69. inp: input tensor.
  70. Returns:
  71. result tensor.
  72. Examples:
  73. .. testcode::
  74. from megengine import tensor
  75. import megengine.functional as F
  76. x = tensor([1, float("inf"), 0])
  77. print(F.isinf(x).numpy())
  78. Outputs:
  79. .. testoutput::
  80. [False True False]
  81. """
  82. return abs(inp).astype("float32") == float("inf")
  83. def sign(inp: Tensor):
  84. r"""Returns a new tensor representing the sign of each element in input tensor.
  85. Args:
  86. inp: Tensor:
  87. Returns:
  88. the sign of input tensor.
  89. Examples:
  90. .. testcode::
  91. from megengine import tensor
  92. import megengine.functional as F
  93. x = tensor([1, -1, 0])
  94. print(F.sign(x).numpy())
  95. Outputs:
  96. .. testoutput::
  97. [ 1 -1 0]
  98. """
  99. return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
  100. def sum(
  101. inp: Tensor,
  102. axis: Optional[Union[int, Sequence[int]]] = None,
  103. keepdims: bool = False,
  104. ) -> Tensor:
  105. r"""Returns the sum of input tensor along given axis. If axis is a list of dimensions,
  106. reduce over all of them.
  107. Args:
  108. inp: input tensor.
  109. axis: dimension to reduce. If None, all dimensions will be reduced.
  110. Default: None
  111. keepdims: whether the output tensor has axis retained or not.
  112. Default: False
  113. Returns:
  114. output tensor.
  115. Examples:
  116. .. testcode::
  117. import numpy as np
  118. from megengine import tensor
  119. import megengine.functional as F
  120. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  121. out = F.sum(x)
  122. print(out.numpy())
  123. Outputs:
  124. .. testoutput::
  125. 21
  126. """
  127. return inp.sum(axis=axis, keepdims=keepdims)
  128. def prod(
  129. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  130. ) -> Tensor:
  131. r"""Returns the product of input tensor along given axis. If axis is a list of dimensions,
  132. reduce over all of them.
  133. Args:
  134. inp: input tensor.
  135. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  136. keepdims: whether the output tensor has axis retained or not. Default: False
  137. Returns:
  138. output tensor.
  139. Examples:
  140. .. testcode::
  141. import numpy as np
  142. from megengine import tensor
  143. import megengine.functional as F
  144. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  145. out = F.prod(x)
  146. print(out.numpy())
  147. Outputs:
  148. .. testoutput::
  149. 720
  150. """
  151. return inp.prod(axis=axis, keepdims=keepdims)
  152. def mean(
  153. inp: Tensor,
  154. axis: Optional[Union[int, Sequence[int]]] = None,
  155. keepdims: bool = False,
  156. ) -> Tensor:
  157. r"""Returns the mean value of input tensor along
  158. given axis. If axis is a list of dimensions,
  159. reduce over all of them.
  160. Args:
  161. inp: input tensor.
  162. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  163. keepdims: whether the output tensor has axis retained or not. Default: False
  164. Returns:
  165. output tensor.
  166. Examples:
  167. .. testcode::
  168. import numpy as np
  169. from megengine import tensor
  170. import megengine.functional as F
  171. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  172. out = F.mean(x)
  173. print(out.numpy())
  174. Outputs:
  175. .. testoutput::
  176. 3.5
  177. """
  178. return inp.mean(axis=axis, keepdims=keepdims)
  179. def var(
  180. inp: Tensor,
  181. axis: Optional[Union[int, Sequence[int]]] = None,
  182. keepdims: bool = False,
  183. ) -> Tensor:
  184. r"""Returns the variance value of input tensor along
  185. given axis. If axis is a list of dimensions,
  186. reduce over all of them.
  187. Args:
  188. inp: input tensor.
  189. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  190. keepdims: whether the output tensor has axis retained or not. Default: False
  191. Returns:
  192. output tensor.
  193. Examples:
  194. .. testcode::
  195. import numpy as np
  196. from megengine import tensor
  197. import megengine.functional as F
  198. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  199. out = F.var(data)
  200. print(out.numpy().round(decimals=4))
  201. Outputs:
  202. .. testoutput::
  203. 2.9167
  204. """
  205. if axis is None:
  206. m = mean(inp, axis=axis, keepdims=False)
  207. else:
  208. m = mean(inp, axis=axis, keepdims=True)
  209. v = inp - m
  210. return mean(v ** 2, axis=axis, keepdims=keepdims)
  211. def std(
  212. inp: Tensor,
  213. axis: Optional[Union[int, Sequence[int]]] = None,
  214. keepdims: bool = False,
  215. ) -> Tensor:
  216. r"""Returns the standard deviation of input tensor along
  217. given axis. If axis is a list of dimensions,
  218. reduce over all of them.
  219. Args:
  220. inp: input tensor.
  221. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  222. keepdims: whether the output tensor has axis retained or not. Default: False
  223. Returns:
  224. output tensor.
  225. Examples:
  226. .. testcode::
  227. import numpy as np
  228. from megengine import tensor
  229. import megengine.functional as F
  230. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  231. out = F.std(data, axis=1)
  232. print(out.numpy().round(decimals=4))
  233. Outputs:
  234. .. testoutput::
  235. [0.8165 0.8165]
  236. """
  237. return var(inp, axis=axis, keepdims=keepdims) ** 0.5
  238. def min(
  239. inp: Tensor,
  240. axis: Optional[Union[int, Sequence[int]]] = None,
  241. keepdims: bool = False,
  242. ) -> Tensor:
  243. r"""Returns the min value of input tensor along
  244. given axis. If axis is a list of dimensions,
  245. reduce over all of them.
  246. Args:
  247. inp: input tensor.
  248. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  249. keepdims: whether the output tensor has axis retained or not. Default: False
  250. Returns:
  251. output tensor.
  252. Examples:
  253. .. testcode::
  254. import numpy as np
  255. from megengine import tensor
  256. import megengine.functional as F
  257. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  258. out = F.min(x)
  259. print(out.numpy())
  260. Outputs:
  261. .. testoutput::
  262. 1
  263. """
  264. return inp.min(axis=axis, keepdims=keepdims)
  265. def max(
  266. inp: Tensor,
  267. axis: Optional[Union[int, Sequence[int]]] = None,
  268. keepdims: bool = False,
  269. ) -> Tensor:
  270. r"""Returns the max value of the input tensor along
  271. given axis. If axis is a list of dimensions,
  272. reduce over all of them.
  273. Args:
  274. inp: input tensor.
  275. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  276. keepdims: whether the output tensor has axis retained or not. Default: False
  277. Returns:
  278. output tensor.
  279. Examples:
  280. .. testcode::
  281. import numpy as np
  282. from megengine import tensor
  283. import megengine.functional as F
  284. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  285. out = F.max(x)
  286. print(out.numpy())
  287. Outputs:
  288. .. testoutput::
  289. 6
  290. """
  291. return inp.max(axis=axis, keepdims=keepdims)
  292. def norm(
  293. inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  294. ):
  295. r"""Calculates ``p``-norm of input tensor along
  296. given axis.
  297. Args:
  298. inp: input tensor.
  299. ord: power of value applied to inp. Default: 2
  300. axis: dimension to reduce. If None, input must be a vector. Default: None
  301. keepdims: whether the output tensor has axis retained or not. Default: False
  302. Returns:
  303. output tensor.
  304. Examples:
  305. .. testcode::
  306. import numpy as np
  307. from megengine import tensor
  308. import megengine.functional as F
  309. x = tensor(np.arange(-3, 3, dtype=np.float32))
  310. out = F.norm(x)
  311. print(out.numpy().round(decimals=4))
  312. Outputs:
  313. .. testoutput::
  314. 4.3589
  315. """
  316. if axis is None:
  317. if inp.ndim != 1:
  318. raise TypeError("axis is required unless input is a vector")
  319. if ord is None:
  320. ord = 2
  321. if ord == 0:
  322. return sum(inp != 0, axis=axis, keepdims=keepdims)
  323. if ord == math.inf:
  324. return max(abs(inp))
  325. if ord == -math.inf:
  326. return min(abs(inp))
  327. return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  328. def argmin(
  329. inp: Tensor,
  330. axis: Optional[Union[int, Sequence[int]]] = None,
  331. keepdims: bool = False,
  332. ) -> Tensor:
  333. r"""Returns the indices of the minimum values along
  334. given axis. If axis is a list of dimensions,
  335. reduce over all of them.
  336. Args:
  337. inp: input tensor.
  338. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  339. keepdims: whether the output tensor has axis retained or not. Default: False
  340. Returns:
  341. output tensor.
  342. Examples:
  343. .. testcode::
  344. import numpy as np
  345. from megengine import tensor
  346. import megengine.functional as F
  347. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  348. out = F.argmin(x)
  349. print(out.numpy())
  350. Outputs:
  351. .. testoutput::
  352. 0
  353. """
  354. if axis is None:
  355. assert not keepdims, "can not set axis=None and keepdims=True"
  356. inp = inp.flatten()
  357. axis = 0
  358. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  359. if isinstance(axis, collections.abc.Iterable):
  360. for ai in axis:
  361. op = builtin.Argmin(axis=ai)
  362. (inp,) = apply(op, inp)
  363. if not keepdims:
  364. inp = squeeze(inp, ai)
  365. return inp
  366. op = builtin.Argmin(axis=axis)
  367. (result,) = apply(op, inp)
  368. if not keepdims:
  369. result = squeeze(result, axis)
  370. return result
  371. def argmax(
  372. inp: Tensor,
  373. axis: Optional[Union[int, Sequence[int]]] = None,
  374. keepdims: bool = False,
  375. ) -> Tensor:
  376. r"""Returns the indices of the maximum values along
  377. given axis. If axis is a list of dimensions,
  378. reduce over all of them.
  379. Args:
  380. inp: input tensor.
  381. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  382. keepdims: whether the output tensor has axis retained or not. Default: False
  383. Returns:
  384. output tensor.
  385. Examples:
  386. .. testcode::
  387. import numpy as np
  388. from megengine import tensor
  389. import megengine.functional as F
  390. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  391. out = F.argmax(x)
  392. print(out.numpy())
  393. Outputs:
  394. .. testoutput::
  395. 5
  396. """
  397. if axis is None:
  398. assert not keepdims, "can not set axis=None and keepdims=True"
  399. inp = inp.flatten()
  400. axis = 0
  401. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  402. if isinstance(axis, collections.abc.Iterable):
  403. for ai in axis:
  404. op = builtin.Argmax(axis=ai)
  405. (inp,) = apply(op, inp)
  406. if not keepdims:
  407. inp = squeeze(inp, ai)
  408. return inp
  409. op = builtin.Argmax(axis=axis)
  410. (result,) = apply(op, inp)
  411. if not keepdims:
  412. result = squeeze(result, axis)
  413. return result
  414. def normalize(
  415. inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  416. ) -> Tensor:
  417. r"""Performs :math:`L_p` normalization of input tensor along
  418. given axis.
  419. For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
  420. :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
  421. .. math::
  422. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
  423. Args:
  424. inp: input tensor.
  425. ord: power of value applied to input tensor. Default: 2
  426. axis: dimension to reduce.If None, input must be a vector. Default: None
  427. eps: a small value to avoid division by zero. Default: 1e-12
  428. Returns:
  429. normalized output tensor.
  430. """
  431. if axis is None:
  432. return inp / clip(norm(inp, ord, axis), lower=eps)
  433. else:
  434. return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
  435. def argsort(inp: Tensor, descending: bool = False) -> Tensor:
  436. r"""Returns the indices that would sort the input tensor.
  437. Args:
  438. inp: input tensor. If it's 2d, the result would be array of indices show how to sort each row in the input tensor.
  439. descending: sort in descending order, where the largest comes first. Default: False
  440. inp: Tensor:
  441. descending: bool:
  442. Returns:
  443. indices of int32 indicates how to sort the input.
  444. Examples:
  445. .. testcode::
  446. import numpy as np
  447. from megengine import tensor
  448. import megengine.functional as F
  449. x = tensor(np.array([1,2], dtype=np.float32))
  450. indices = F.argsort(x)
  451. print(indices.numpy())
  452. Outputs:
  453. .. testoutput::
  454. [0 1]
  455. """
  456. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  457. if descending:
  458. order = "descending"
  459. else:
  460. order = "ascending"
  461. op = builtin.Argsort(order=order)
  462. if len(inp.shape) == 1:
  463. inp = inp.reshape(1, -1)
  464. _, result = apply(op, inp)
  465. return result[0]
  466. _, result = apply(op, inp)
  467. return result
  468. def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  469. r"""Returns sorted tensor and the indices would sort the input tensor.
  470. Args:
  471. inp: input tensor. If it's 2d, the result would be sorted by row.
  472. descending: sort in descending order, where the largest comes first. Default: False
  473. Returns:
  474. tuple of two tensors `(sorted_tensor, indices_of_int32)`.
  475. Examples:
  476. .. testcode::
  477. import numpy as np
  478. from megengine import tensor
  479. import megengine.functional as F
  480. x = tensor(np.array([1,2], dtype=np.float32))
  481. out, indices = F.sort(x)
  482. print(out.numpy())
  483. Outputs:
  484. .. testoutput::
  485. [1. 2.]
  486. """
  487. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  488. if descending:
  489. order = "descending"
  490. else:
  491. order = "ascending"
  492. op = builtin.Argsort(order=order)
  493. if len(inp.shape) == 1:
  494. inp = inp.reshape(1, -1)
  495. tns, ind = apply(op, inp)
  496. return tns[0], ind[0]
  497. tns, ind = apply(op, inp)
  498. return tns, ind
  499. def topk(
  500. inp: Tensor,
  501. k: int,
  502. descending: bool = False,
  503. kth_only: bool = False,
  504. no_sort: bool = False,
  505. ) -> Tuple[Tensor, Tensor]:
  506. r"""Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.
  507. Args:
  508. inp: input tensor. If input tensor is 2d, each row will be sorted.
  509. k: number of elements needed.
  510. descending: if True, return the largest elements instead. Default: False
  511. kth_only: if True, only the k-th element will be returned. Default: False
  512. no_sort: if True, the returned elements can be unordered. Default: False
  513. Returns:
  514. tuple of two tensors ``(topk_tensor, indices_of_int32)``
  515. Examples:
  516. .. testcode::
  517. import numpy as np
  518. from megengine import tensor
  519. import megengine.functional as F
  520. x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  521. top, indices = F.topk(x, 5)
  522. print(top.numpy(), indices.numpy())
  523. Outputs:
  524. .. testoutput::
  525. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  526. """
  527. if descending:
  528. k = -k
  529. if kth_only:
  530. mode = "kth_only"
  531. elif no_sort:
  532. mode = "value_idx_nosort"
  533. else:
  534. mode = "value_idx_sorted"
  535. op = builtin.TopK(mode=mode)
  536. if not isinstance(k, Tensor):
  537. (k,) = Const(k, dtype="int32", device=inp.device)()
  538. if len(inp.shape) == 1:
  539. if kth_only:
  540. (tns,) = apply(op, expand_dims(inp, 0), k)
  541. # FIXME:
  542. # could use a dedicated kernel
  543. # gradient may be routed to other indices if k-th value is not unique
  544. ind = argmax((tns == inp).astype("int8"))
  545. tns = squeeze(tns, 0)
  546. else:
  547. tns, ind = apply(op, expand_dims(inp, 0), k)
  548. tns = squeeze(tns, 0)
  549. ind = squeeze(ind, 0)
  550. else:
  551. if kth_only:
  552. (tns,) = apply(op, inp, k)
  553. # FIXME: same as above
  554. ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1)
  555. else:
  556. tns, ind = apply(op, inp, k)
  557. return tns, ind
  558. def matinv(inp: Tensor) -> Tensor:
  559. r"""Computes the inverse of a batch of matrices; input must has shape [..., n, n].
  560. Args:
  561. inp: input tensor.
  562. Returns:
  563. output tensor.
  564. Examples:
  565. .. testcode::
  566. import numpy as np
  567. from megengine import tensor
  568. import megengine.functional as F
  569. data = tensor([[1.0, 0.0], [1.0, 1.0]])
  570. out = F.matinv(data)
  571. print(out.numpy())
  572. Outputs:
  573. .. testoutput::
  574. [[ 1. 0.]
  575. [-1. 1.]]
  576. """
  577. (result,) = apply(builtin.MatrixInverse(), inp)
  578. return result
  579. class _Hashable:
  580. def __init__(self, value) -> None:
  581. self.value = value
  582. def __hash__(self) -> int:
  583. return hash(str(self.value))
  584. def __eq__(self, o: object) -> bool:
  585. if not isinstance(o, _Hashable):
  586. return False
  587. return self.value == o.value
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    """Builds the subgraph op that extends ``MatrixMul`` to non-2d inputs.

    The builder is memoized with ``lru_cache``, so every argument must be
    hashable; ``matmul`` passes the strategy wrapped in ``_Hashable``.

    Args:
        device: device handed to ``subgraph`` and used as ``comp_node`` of
            the shape ``Concat`` ops.
        dtype: dtype handed to ``subgraph``.
        dim1: number of dimensions of the first input.
        dim2: number of dimensions of the second input.
        transpose_a: forwarded to ``MatrixMul`` as ``transposeA``.
        transpose_b: forwarded to ``MatrixMul`` as ``transposeB``.
        compute_mode: forwarded to ``MatrixMul``.
        format: forwarded to ``MatrixMul``.
        strategy: object whose ``.value`` is forwarded to ``MatrixMul`` as
            ``strategy``.

    Returns:
        the function decorated by ``subgraph``; ``matmul`` calls it with no
        arguments to obtain the op it then applies to the two inputs.
    """

    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        # ``f`` applies an op to graph values and ``c`` creates a constant,
        # judging by how they are used below.
        assert len(inputs) == 2
        inp1, inp2 = inputs
        # Work on copies: the effective ranks are bumped to 2 for 1-d inputs.
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        # A 1-d operand is temporarily promoted to 2-d; the flags remember
        # which helper axis has to be removed from the result again.
        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True

        # The first operand gains a leading axis, the second a trailing one.
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)

        # Shapes are captured after the axis insertion and before flattening.
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        # Flatten an operand with more than two dims to 2-d:
        # (prod(shape[:-1]), shape[-1]).
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                    build_shape_tail(shape1),
                ),
            )
        # NOTE(review): within this file ``matmul`` only dispatches here when
        # ``maxdim <= 2 or dim2 <= 2``, so this branch looks unused by it.
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        result_shape = f(GetVarShape(), result)
        # Restore the flattened leading dims: original shape[:-1] followed by
        # the last dim of the 2-d product.
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        # Drop the helper axes that were added for 1-d operands.
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        # NOTE(review): the second tuple presumably flags per-output gradient
        # support; confirm against the ``subgraph`` helper.
        return (result,), (True,)

    return extentedMatrixMulOp
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    """Builds the subgraph op that extends ``BatchedMatrixMul`` to other ranks.

    The lower-rank operand is broadcast over the leading dims of the
    higher-rank one, and more than one batch dim is collapsed into a single
    batch dim around the ``BatchedMatrixMul`` call. The builder is memoized
    with ``lru_cache``, so every argument must be hashable; ``matmul``
    passes the strategy wrapped in ``_Hashable``.

    Args:
        device: device handed to ``subgraph`` and used as ``comp_node`` of
            the shape ``Concat`` ops.
        dtype: dtype handed to ``subgraph``.
        dim1: number of dimensions of the first input.
        dim2: number of dimensions of the second input.
        transpose_a: forwarded to ``BatchedMatrixMul`` as ``transposeA``.
        transpose_b: forwarded to ``BatchedMatrixMul`` as ``transposeB``.
        compute_mode: forwarded to ``BatchedMatrixMul``.
        format: forwarded to ``BatchedMatrixMul``.
        strategy: object whose ``.value`` is forwarded to
            ``BatchedMatrixMul`` as ``strategy``.

    Returns:
        the function decorated by ``subgraph``; ``matmul`` calls it with no
        arguments to obtain the op it then applies to the two inputs.
    """

    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        # ``f`` applies an op to graph values and ``c`` creates a constant,
        # judging by how they are used below.
        assert len(inputs) == 2
        inp1, inp2 = inputs
        # Work on copies: the effective ranks are bumped to 2 for 1-d inputs.
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        # A 1-d operand is temporarily promoted to 2-d; the flags remember
        # which helper axis has to be removed from the result again.
        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True

        # The first operand gains a leading axis, the second a trailing one.
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)

        # Bring both operands to the same rank. Exactly one of the three
        # branches below runs, and each one defines ``batch_shape`` as the
        # shape without its last two dims.
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if _dim1 > _dim2:
            # broadcast
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)

        # compress inputs to 3d
        # (prod(batch_shape), *shape[-2:]) for both operands.
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)

        # Expand the single batch dim back to the original batch dims,
        # followed by the last two dims of the product.
        if maxdim > 3:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(GetVarShape(), result)),
                ),
            )
        # Drop the helper axes that were added for 1-d operands.
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        # NOTE(review): the second tuple presumably flags per-output gradient
        # support; confirm against the ``subgraph`` helper.
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
  781. def matmul(
  782. inp1: Tensor,
  783. inp2: Tensor,
  784. transpose_a=False,
  785. transpose_b=False,
  786. compute_mode="default",
  787. format="default",
  788. ) -> Tensor:
  789. r"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
  790. With different inputs dim, this function behaves differently:
  791. * Both 1-D tensor, simply forward to ``dot``.
  792. * Both 2-D tensor, normal matrix multiplication.
  793. * If one input tensor is 1-D, matrix vector multiplication.
  794. * If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2,
  795. the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted.
  796. For example:
  797. * inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
  798. * inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
  799. * inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
  800. Args:
  801. inp1: first matrix to be multiplied.
  802. inp2: second matrix to be multiplied.
  803. Returns:
  804. output tensor.
  805. Examples:
  806. .. testcode::
  807. import numpy as np
  808. from megengine import tensor
  809. import megengine.functional as F
  810. data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  811. data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
  812. out = F.matmul(data1, data2)
  813. print(out.numpy())
  814. Outputs:
  815. .. testoutput::
  816. [[10. 13.]
  817. [28. 40.]]
  818. """
  819. if amp._enabled:
  820. compute_mode = "float32"
  821. inp1, inp2 = cast_tensors(inp1, inp2)
  822. else:
  823. dtype = dtype_promotion(inp1, inp2)
  824. if inp1.dtype != dtype:
  825. inp1 = inp1.astype(dtype)
  826. if inp2.dtype != dtype:
  827. inp2 = inp2.astype(dtype)
  828. dim1, dim2 = inp1.ndim, inp2.ndim
  829. assert dim1 > 0 and dim2 > 0
  830. maxdim = dim1 if dim1 > dim2 else dim2
  831. if dim1 == 1 and dim2 == 1: # dispatch to Dot
  832. return dot(inp1, inp2)
  833. elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul
  834. extentedMatrixMulOp = _get_extentedMatrixMulOp(
  835. inp1.device,
  836. inp1.dtype,
  837. dim1,
  838. dim2,
  839. transpose_a,
  840. transpose_b,
  841. compute_mode,
  842. format,
  843. strategy=_Hashable(get_execution_strategy()),
  844. )
  845. (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
  846. return result
  847. else: # dispath to BatchedMatrixMul
  848. extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
  849. inp1.device,
  850. inp1.dtype,
  851. dim1,
  852. dim2,
  853. transpose_a,
  854. transpose_b,
  855. compute_mode,
  856. format,
  857. strategy=_Hashable(get_execution_strategy()),
  858. )
  859. (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
  860. return result
  861. def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
  862. r"""Computes dot-product of two vectors ``inp1`` and ``inp2``.
  863. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
  864. Refer to :func:`~.matmul` for more general usage.
  865. Args:
  866. inp1: first vector.
  867. inp2: second vector.
  868. Returns:
  869. output value.
  870. Examples:
  871. .. testcode::
  872. import numpy as np
  873. from megengine import tensor
  874. import megengine.functional as F
  875. data1 = tensor(np.arange(0, 6, dtype=np.float32))
  876. data2 = tensor(np.arange(0, 6, dtype=np.float32))
  877. out = F.dot(data1, data2)
  878. print(out.numpy())
  879. Outputs:
  880. .. testoutput::
  881. 55.
  882. """
  883. op = builtin.Dot()
  884. assert (
  885. inp1.ndim <= 1 and inp2.ndim <= 1
  886. ), "Input tensors for dot must be 1-dimensional or scalar"
  887. (result,) = apply(op, inp1, inp2)
  888. setscalar(result)
  889. return result
  890. def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
  891. r"""Computes the singular value decompositions of input matrix.
  892. Args:
  893. inp: input matrix, must has shape `[..., M, N]`.
  894. Returns:
  895. output matrices, `(U, sigma, V)`.
  896. Examples:
  897. .. testcode::
  898. import numpy as np
  899. from megengine import tensor
  900. import megengine.functional as F
  901. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
  902. _, y, _ = F.svd(x)
  903. print(y.numpy().round(decimals=3))
  904. Outputs:
  905. .. testoutput::
  906. [7.348 1. ]
  907. """
  908. op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
  909. U, sigma, V = apply(op, inp)
  910. return U, sigma, V
  911. def _check_non_finite(inp: Tensor) -> Tensor:
  912. r"""Check whether input contains infinite or nan value.
  913. Args:
  914. inp: a tensor to be checked.
  915. Returns:
  916. a int32 scalar tensor, 0 for False and 1 for True.
  917. """
  918. op = builtin.CheckNonFinite()
  919. (oup,) = apply(op, inp.reshape(-1).astype("float32"))
  920. oup._setscalar()
  921. return oup

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台