You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

math.py 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import math
  11. from functools import lru_cache
  12. from typing import Optional, Sequence, Tuple, Union
  13. from ..core._imperative_rt.core2 import apply, dtype_promotion
  14. from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
  15. from ..core._trace_option import use_symbolic_shape
  16. from ..core.ops import builtin
  17. from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
  18. from ..core.ops.special import Const
  19. from ..core.tensor import amp
  20. from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
  21. from ..jit import exclude_from_trace
  22. from ..tensor import Tensor
  23. from .debug_param import get_execution_strategy
  24. from .elemwise import clip, minimum
  25. from .tensor import broadcast_to, concat, expand_dims, squeeze
  # Public API of this module: the names exported by ``from <module> import *``.
  # Kept in alphabetical order.
  __all__ = [
      "argmax",
      "argmin",
      "argsort",
      "dot",
      "isinf",
      "isnan",
      "matinv",
      "matmul",
      "max",
      "mean",
      "min",
      "norm",
      "normalize",
      "prod",
      "sign",
      "sort",
      "std",
      "sum",
      "svd",
      "topk",
      "var",
  ]
  def isnan(inp: Tensor) -> Tensor:
      r"""
      Returns a new tensor indicating, element by element, whether ``inp`` is ``NaN``.

      :param inp: input tensor.
      :return: boolean tensor, ``True`` where the element of ``inp`` is ``NaN``.

      Examples:

      .. testcode::

          from megengine import tensor
          import megengine.functional as F

          x = tensor([1, float("nan"), 0])
          print(F.isnan(x).numpy())

      Outputs:

      .. testoutput::

          [False  True False]

      """
      # NaN is the only value that does not compare equal to itself.
      differs_from_itself = inp != inp
      return differs_from_itself
  def isinf(inp: Tensor) -> Tensor:
      r"""
      Returns a new tensor indicating, element by element, whether ``inp`` is infinite
      (either ``+Inf`` or ``-Inf``).

      :param inp: input tensor.
      :return: boolean tensor, ``True`` where the element of ``inp`` is infinite.

      Examples:

      .. testcode::

          from megengine import tensor
          import megengine.functional as F

          x = tensor([1, float("inf"), 0])
          print(F.isinf(x).numpy())

      Outputs:

      .. testoutput::

          [False  True False]

      """
      # Taking the magnitude folds -Inf onto +Inf, so one comparison covers both.
      # NOTE(review): the float32 cast presumably lets integer inputs be compared
      # against a float infinity — confirm.
      magnitude = abs(inp).astype("float32")
      return magnitude == float("inf")
  def sign(inp: Tensor):
      r"""
      Returns a new tensor holding the sign of each element of the input tensor:
      ``1`` for positive, ``-1`` for negative and ``0`` for zero, in ``inp.dtype``.

      :param inp: input tensor.
      :return: the sign of input tensor.

      Examples:

      .. testcode::

          from megengine import tensor
          import megengine.functional as F

          x = tensor([1, -1, 0])
          print(F.sign(x).numpy())

      Outputs:

      .. testoutput::

          [ 1 -1  0]

      """
      dtype = inp.dtype
      # Each mask is 1 where the condition holds and 0 elsewhere, so their
      # difference is 1, -1 or 0.
      is_positive = (inp > 0).astype(dtype)
      is_negative = (inp < 0).astype(dtype)
      return is_positive - is_negative
  def sum(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""
      Returns the sum of the input tensor along the given axis. If ``axis`` is a
      list of dimensions, all of them are reduced.

      :param inp: input tensor.
      :param axis: dimension to reduce. If None, all dimensions will be reduced.
          Default: None
      :param keepdims: whether the output tensor has axis retained or not.
          Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
          out = F.sum(x)
          print(out.numpy())

      Outputs:

      .. testoutput::

          21

      """
      # Thin wrapper over the Tensor method of the same name.
      total = inp.sum(keepdims=keepdims, axis=axis)
      return total
  def prod(
      inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  ) -> Tensor:
      r"""
      Returns the product of the input tensor along the given axis. If ``axis`` is
      a list of dimensions, all of them are reduced.

      :param inp: input tensor.
      :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
      :param keepdims: whether the output tensor has axis retained or not. Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
          out = F.prod(x)
          print(out.numpy())

      Outputs:

      .. testoutput::

          720

      """
      # Thin wrapper over the Tensor method of the same name.
      product = inp.prod(keepdims=keepdims, axis=axis)
      return product
  def mean(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      """
      Returns the mean value of the input tensor along the given axis. If ``axis``
      is a list of dimensions, all of them are reduced.

      :param inp: input tensor.
      :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
      :param keepdims: whether the output tensor has axis retained or not. Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
          out = F.mean(x)
          print(out.numpy())

      Outputs:

      .. testoutput::

          3.5

      """
      # Thin wrapper over the Tensor method of the same name.
      average = inp.mean(keepdims=keepdims, axis=axis)
      return average
  def var(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      """
      Returns the variance of the input tensor along the given axis, computed as
      the mean of squared deviations from the mean (i.e. divided by the element
      count). If ``axis`` is a list of dimensions, all of them are reduced.

      :param inp: input tensor.
      :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
      :param keepdims: whether the output tensor has axis retained or not. Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
          out = F.var(data)
          print(out.numpy().round(decimals=4))

      Outputs:

      .. testoutput::

          2.9167

      """
      # When reducing along specific axes, keep them so the mean broadcasts back
      # against ``inp``; a full reduction (axis=None) is not kept.
      center = mean(inp, axis=axis, keepdims=axis is not None)
      squared_deviation = (inp - center) ** 2
      return mean(squared_deviation, axis=axis, keepdims=keepdims)
  def std(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      """
      Returns the standard deviation of the input tensor along the given axis,
      i.e. the square root of :func:`var`. If ``axis`` is a list of dimensions,
      all of them are reduced.

      :param inp: input tensor.
      :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
      :param keepdims: whether the output tensor has axis retained or not. Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
          out = F.std(data, axis=1)
          print(out.numpy().round(decimals=4))

      Outputs:

      .. testoutput::

          [0.8165 0.8165]

      """
      variance = var(inp, axis=axis, keepdims=keepdims)
      return variance ** 0.5
  def min(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""
      Returns the minimum value of the input tensor along the given axis. If
      ``axis`` is a list of dimensions, all of them are reduced.

      :param inp: input tensor.
      :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
      :param keepdims: whether the output tensor has axis retained or not. Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
          out = F.min(x)
          print(out.numpy())

      Outputs:

      .. testoutput::

          1

      """
      # Thin wrapper over the Tensor method of the same name.
      smallest = inp.min(keepdims=keepdims, axis=axis)
      return smallest
  def max(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""
      Returns the maximum value of the input tensor along the given axis. If
      ``axis`` is a list of dimensions, all of them are reduced.

      :param inp: input tensor.
      :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
      :param keepdims: whether the output tensor has axis retained or not. Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
          out = F.max(x)
          print(out.numpy())

      Outputs:

      .. testoutput::

          6

      """
      # Thin wrapper over the Tensor method of the same name.
      largest = inp.max(keepdims=keepdims, axis=axis)
      return largest
  def norm(
      inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  ):
      """
      Calculates the ``p``-norm of the input tensor along the given axis.

      :param inp: input tensor.
      :param ord: order of the norm. ``0`` counts the non-zero elements;
          ``math.inf`` / ``-math.inf`` take the maximum / minimum absolute value;
          any other value ``p`` computes ``sum(abs(inp) ** p) ** (1 / p)``.
          Default: 2
      :param axis: dimension to reduce. If None, input must be a vector. Default: None
      :param keepdims: whether the output tensor has axis retained or not. Default: False
      :return: output tensor.
      :raises TypeError: if ``axis`` is None and ``inp`` is not 1-dimensional.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(-3, 3, dtype=np.float32))
          out = F.norm(x)
          print(out.numpy().round(decimals=4))

      Outputs:

      .. testoutput::

          4.3589

      """
      if axis is None:
          if inp.ndim != 1:
              raise TypeError("axis is required unless input is a vector")
      if ord is None:
          ord = 2
      if ord == 0:
          return sum(inp != 0, axis=axis, keepdims=keepdims)
      # The infinity norms must reduce along ``axis`` and honour ``keepdims``
      # like every other order; without them they collapse the whole tensor,
      # which also breaks row-wise callers such as ``normalize``.
      if ord == math.inf:
          return max(abs(inp), axis=axis, keepdims=keepdims)
      if ord == -math.inf:
          return min(abs(inp), axis=axis, keepdims=keepdims)
      return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  def argmin(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""
      Returns the indices of the minimum values along the given axis. If ``axis``
      is a list of dimensions, all of them are reduced.

      :param inp: input tensor.
      :param axis: dimension to reduce. If None, all dimensions will be reduced
          (the index then refers to the flattened tensor). Default: None
      :param keepdims: whether the output tensor has axis retained or not. Must be
          False when ``axis`` is None. Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
          out = F.argmin(x)
          print(out.numpy())

      Outputs:

      .. testoutput::

          0

      """
      if axis is None:
          assert not keepdims, "can not set axis=None and keepdims=True"
          # A global argmin is an argmin over the flattened tensor.
          inp = inp.flatten()
          axis = 0
      # NOTE(review): reverse=True presumably yields axes in descending order so
      # that squeezing one axis does not shift the ones still to process — confirm.
      axis = _normalize_axis(inp.ndim, axis, reverse=True)
      if not isinstance(axis, collections.abc.Iterable):
          (result,) = apply(builtin.Argmin(axis=axis), inp)
          return result if keepdims else squeeze(result, axis)
      # NOTE(review): each pass applies Argmin to the previous pass's output, i.e.
      # to indices rather than to the original values — confirm that this is the
      # intended multi-axis semantics.
      for single_axis in axis:
          (inp,) = apply(builtin.Argmin(axis=single_axis), inp)
          if not keepdims:
              inp = squeeze(inp, single_axis)
      return inp
  def argmax(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""
      Returns the indices of the maximum values along the given axis. If ``axis``
      is a list of dimensions, all of them are reduced.

      :param inp: input tensor.
      :param axis: dimension to reduce. If None, all dimensions will be reduced
          (the index then refers to the flattened tensor). Default: None
      :param keepdims: whether the output tensor has axis retained or not. Must be
          False when ``axis`` is None. Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
          out = F.argmax(x)
          print(out.numpy())

      Outputs:

      .. testoutput::

          5

      """
      if axis is None:
          assert not keepdims, "can not set axis=None and keepdims=True"
          # A global argmax is an argmax over the flattened tensor.
          inp = inp.flatten()
          axis = 0
      # NOTE(review): reverse=True presumably yields axes in descending order so
      # that squeezing one axis does not shift the ones still to process — confirm.
      axis = _normalize_axis(inp.ndim, axis, reverse=True)
      if not isinstance(axis, collections.abc.Iterable):
          (result,) = apply(builtin.Argmax(axis=axis), inp)
          return result if keepdims else squeeze(result, axis)
      # NOTE(review): each pass applies Argmax to the previous pass's output, i.e.
      # to indices rather than to the original values — confirm that this is the
      # intended multi-axis semantics.
      for single_axis in axis:
          (inp,) = apply(builtin.Argmax(axis=single_axis), inp)
          if not keepdims:
              inp = squeeze(inp, single_axis)
      return inp
  def normalize(
      inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  ) -> Tensor:
      r"""
      Performs :math:`L_p` normalization of the input tensor along the given axis.

      For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
      :math:`n_{dim}`-element vector :math:`v` along dimension :attr:`axis` is transformed as:

      .. math::

          v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

      :param inp: input tensor.
      :param ord: power of value applied to input tensor (see :func:`norm`). Default: 2
      :param axis: dimension to reduce. If None, input must be a vector. Default: None
      :param eps: a small value to avoid division by zero. Default: 1e-12
      :return: normalized output tensor.
      """
      # Keep the reduced axis (when one is given) so the norm broadcasts back
      # against ``inp`` in the division below.
      denominator = norm(inp, ord, axis, keepdims=axis is not None)
      return inp / clip(denominator, lower=eps)
  def argsort(inp: Tensor, descending: bool = False) -> Tensor:
      r"""
      Returns the indices that would sort the input tensor.

      :param inp: input tensor, 1d or 2d. If it is 2d, each row is sorted
          independently and the result holds the per-row sorting indices.
      :param descending: sort in descending order, where the largest comes first. Default: False
      :return: int32 indices describing how to sort the input.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.array([1,2], dtype=np.float32))
          indices = F.argsort(x)
          print(indices.numpy())

      Outputs:

      .. testoutput::

          [0 1]

      """
      assert len(inp.shape) <= 2, "Input should be 1d or 2d"
      op = builtin.Argsort(order="descending" if descending else "ascending")
      is_vector = len(inp.shape) == 1
      if is_vector:
          # The op works row-wise on 2d input, so treat a vector as a single row.
          inp = inp.reshape(1, -1)
      _, indices = apply(op, inp)
      return indices[0] if is_vector else indices
  def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
      r"""
      Returns the sorted tensor together with the indices that would sort the input.

      :param inp: input tensor, 1d or 2d. If it is 2d, each row is sorted independently.
      :param descending: sort in descending order, where the largest comes first. Default: False
      :return: tuple of two tensors `(sorted_tensor, indices_of_int32)`.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.array([1,2], dtype=np.float32))
          out, indices = F.sort(x)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [1. 2.]

      """
      assert len(inp.shape) <= 2, "Input should be 1d or 2d"
      op = builtin.Argsort(order="descending" if descending else "ascending")
      is_vector = len(inp.shape) == 1
      if is_vector:
          # The op works row-wise on 2d input, so treat a vector as a single row.
          inp = inp.reshape(1, -1)
      values, indices = apply(op, inp)
      if is_vector:
          return values[0], indices[0]
      return values, indices
  def topk(
      inp: Tensor,
      k: int,
      descending: bool = False,
      kth_only: bool = False,
      no_sort: bool = False,
  ) -> Tuple[Tensor, Tensor]:
      r"""
      Selects the ``k`` smallest (or, with ``descending=True``, largest) elements
      of a vector, or of each row of a 2d matrix.

      :param inp: input tensor. If input tensor is 2d, each row is handled independently.
      :param k: number of elements needed.
      :param descending: if True, return the largest elements instead. Default: False
      :param kth_only: if True, only the k-th element will be returned. Its index is
          recovered by matching the value against ``inp``, so if the k-th value is
          not unique the index may point at another occurrence. Default: False
      :param no_sort: if True, the returned elements can be unordered. Default: False
      :return: tuple of two tensors ``(topk_tensor, indices_of_int32)``

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
          top, indices = F.topk(x, 5)
          print(top.numpy(), indices.numpy())

      Outputs:

      .. testoutput::

          [1. 2. 3. 4. 5.] [7 0 6 1 5]

      """
      # The TopK op is asked for the largest elements by negating k.
      if descending:
          k = -k
      if kth_only:
          mode = "kth_only"
      elif no_sort:
          mode = "value_idx_nosort"
      else:
          mode = "value_idx_sorted"
      op = builtin.TopK(mode=mode)
      if not isinstance(k, Tensor):
          (k,) = Const(k, dtype="int32", device=inp.device)()

      is_vector = len(inp.shape) == 1
      # The op works row-wise on a 2d matrix, so lift a vector to a single row.
      rows = expand_dims(inp, 0) if is_vector else inp

      if kth_only:
          (values,) = apply(op, rows, k)
          # FIXME:
          # could use a dedicated kernel
          # gradient may be routed to other indices if k-th value is not unique
          if is_vector:
              indices = argmax((values == inp).astype("int8"))
              values = squeeze(values, 0)
          else:
              indices = argmax((expand_dims(values, 1) == inp).astype("int8"), 1)
          return values, indices

      values, indices = apply(op, rows, k)
      if is_vector:
          values = squeeze(values, 0)
          indices = squeeze(indices, 0)
      return values, indices
  def matinv(inp: Tensor) -> Tensor:
      """
      Computes the inverse of a batch of matrices; the input must have shape [..., n, n].

      :param inp: input tensor.
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          data = tensor([[1.0, 0.0], [1.0, 1.0]])
          out = F.matinv(data)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[ 1.  0.]
           [-1.  1.]]

      """
      inverse_op = builtin.MatrixInverse()
      (inverse,) = apply(inverse_op, inp)
      return inverse
  559. class _Hashable:
  560. def __init__(self, value) -> None:
  561. self.value = value
  562. def __hash__(self) -> int:
  563. return hash(str(self.value))
  564. def __eq__(self, o: object) -> bool:
  565. if not isinstance(o, _Hashable):
  566. return False
  567. return self.value == o.value
  @lru_cache(maxsize=None)
  def _get_extentedMatrixMulOp(
      device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
  ):
      """
      Builds a subgraph op that extends the 2d ``MatrixMul`` kernel to more input ranks.

      The op is cached per argument combination by ``lru_cache``, so every argument
      must be hashable; ``strategy`` is expected to wrap the execution strategy in
      an object exposing ``.value`` (``matmul`` passes a ``_Hashable``).

      Inside the subgraph:

      * a 1d operand is promoted to 2d with ``AddAxis`` (row vector for ``inp1``,
        column vector for ``inp2``), and that axis is removed from the result;
      * an operand with more than 2 dims has all its leading dims folded (by
        product) into a single row dim before ``MatrixMul``, and the result is
        reshaped back to ``shape[:-1] + result_shape[-1:]``.

      :param device: device the subgraph is built for.
      :param dtype: dtype the subgraph is built for.
      :param dim1: ndim of the first operand.
      :param dim2: ndim of the second operand.
      :param transpose_a: forwarded to ``MatrixMul(transposeA=...)``.
      :param transpose_b: forwarded to ``MatrixMul(transposeB=...)``.
      :param compute_mode: forwarded to ``MatrixMul``.
      :param format: forwarded to ``MatrixMul``.
      :param strategy: wrapper whose ``.value`` is forwarded to ``MatrixMul(strategy=...)``.
      :return: the decorated subgraph builder; call it to obtain the op to ``apply``.
      """
      # NOTE(review): the subgraph decorator's positional 2 appears to be the input
      # count (matching the assert below); ``f`` applies an op to vars and ``c``
      # creates a constant, as used here — confirm against ``subgraph``.
      @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
      def extentedMatrixMulOp(inputs, f, c):
          assert len(inputs) == 2
          inp1, inp2 = inputs
          _dim1, _dim2 = dim1, dim2

          def build_shape_head(shape, idx=-1):
              # shape[:idx]
              return f(
                  builtin.Subtensor(items=[[0, False, True, False, False]]),
                  shape,
                  c(idx, "int32"),
              )

          def build_shape_tail(shape, idx=-1):
              # shape[idx:]
              return f(
                  builtin.Subtensor(items=[[0, True, False, False, False]]),
                  shape,
                  c(idx, "int32"),
              )

          # Promote 1d operands to 2d; remember which axis to drop afterwards.
          remove_row, remove_col = False, False
          if _dim1 == 1:
              _dim1 = 2
              remove_row = True
          if _dim2 == 1:
              _dim2 = 2
              remove_col = True

          if remove_row:
              inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
          if remove_col:
              inp2 = f(builtin.AddAxis(axis=[1,]), inp2)

          shape1 = f(GetVarShape(), inp1)
          shape2 = f(GetVarShape(), inp2)
          # Fold leading dims: (d0, ..., dn-2, k) -> (d0 * ... * dn-2, k).
          # NOTE(review): with transpose_a=True this transposes the whole folded
          # matrix, not each trailing 2d matrix, so it is only valid when
          # transpose_a is False for _dim1 > 2.
          if _dim1 > 2:
              inp1 = f(
                  builtin.Reshape(),
                  inp1,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                      build_shape_tail(shape1),
                  ),
              )
          # NOTE(review): ``matmul`` in this file only routes here when dim2 <= 2,
          # so this branch appears unreachable from it.
          if _dim2 > 2:
              inp2 = f(
                  builtin.Reshape(),
                  inp2,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                      build_shape_tail(shape2),
                  ),
              )
          op = builtin.MatrixMul(
              transposeA=transpose_a,
              transposeB=transpose_b,
              compute_mode=compute_mode,
              format=format,
              strategy=strategy.value,
          )
          result = f(op, inp1, inp2)
          result_shape = f(GetVarShape(), result)
          # Unfold the leading dims: result becomes shape1[:-1] + result_shape[-1:].
          if _dim1 > 2:
              result = f(
                  builtin.Reshape(),
                  result,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      build_shape_head(shape1),
                      build_shape_tail(result_shape),
                  ),
              )
          if _dim2 > 2:
              result = f(
                  builtin.Reshape(),
                  result,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      build_shape_head(shape2),
                      build_shape_tail(result_shape),
                  ),
              )
          # Drop the axes introduced for 1d operands; the result has maxdim dims here.
          maxdim = _dim1 if _dim1 > _dim2 else _dim2
          if remove_row:
              result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
          if remove_col:
              result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
          # NOTE(review): second tuple is presumably a per-output flag required by
          # the subgraph decorator — confirm its meaning.
          return (result,), (True,)

      return extentedMatrixMulOp
  @lru_cache(maxsize=None)
  def _get_extentedBatchedMatrixMulOp(
      device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
  ):
      """
      Builds a subgraph op that extends the 3d ``BatchedMatrixMul`` kernel to
      operands of other (and differing) ranks.

      The op is cached per argument combination by ``lru_cache``, so every argument
      must be hashable; ``strategy`` is expected to wrap the execution strategy in
      an object exposing ``.value`` (``matmul`` passes a ``_Hashable``).

      Inside the subgraph:

      * a 1d operand is promoted to 2d with ``AddAxis`` (row vector for ``inp1``,
        column vector for ``inp2``), and that axis is removed from the result;
      * the lower-rank operand is broadcast by prepending the other operand's
        leading (batch) dims to its shape;
      * when the rank exceeds 3, all batch dims are folded into one so the
        kernel sees 3d inputs, and the result is reshaped back to
        ``batch_shape + result_shape[-2:]``.

      :param device: device the subgraph is built for.
      :param dtype: dtype the subgraph is built for.
      :param dim1: ndim of the first operand.
      :param dim2: ndim of the second operand.
      :param transpose_a: forwarded to ``BatchedMatrixMul(transposeA=...)``.
      :param transpose_b: forwarded to ``BatchedMatrixMul(transposeB=...)``.
      :param compute_mode: forwarded to ``BatchedMatrixMul``.
      :param format: forwarded to ``BatchedMatrixMul``.
      :param strategy: wrapper whose ``.value`` is forwarded to ``BatchedMatrixMul(strategy=...)``.
      :return: the decorated subgraph builder; call it to obtain the op to ``apply``.
      """
      # NOTE(review): the subgraph decorator's positional 2 appears to be the input
      # count (matching the assert below); ``f`` applies an op to vars and ``c``
      # creates a constant, as used here — confirm against ``subgraph``.
      @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
      def extentedBatchedMatrixMulOp(inputs, f, c):
          assert len(inputs) == 2
          inp1, inp2 = inputs
          _dim1, _dim2 = dim1, dim2

          def build_shape_head(shape, idx=-2):
              # shape[:idx]
              return f(
                  builtin.Subtensor(items=[[0, False, True, False, False]]),
                  shape,
                  c(idx, "int32"),
              )

          def build_shape_tail(shape, idx=-2):
              # shape[idx:]
              return f(
                  builtin.Subtensor(items=[[0, True, False, False, False]]),
                  shape,
                  c(idx, "int32"),
              )

          # Promote 1d operands to 2d; remember which axis to drop afterwards.
          remove_row, remove_col = False, False
          if _dim1 == 1:
              _dim1 = 2
              remove_row = True
          if _dim2 == 1:
              _dim2 = 2
              remove_col = True

          if remove_row:
              inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
          if remove_col:
              inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
          shape1 = f(GetVarShape(), inp1)
          shape2 = f(GetVarShape(), inp2)
          maxdim = _dim1 if _dim1 > _dim2 else _dim2
          # Rank broadcasting: prepend the missing leading dims of the longer shape.
          # NOTE(review): size-1 batch dims of equal rank are not expanded here;
          # mismatched batch dims are left for BatchedMatrixMul to reject — confirm.
          if _dim1 > _dim2:
              # broadcast
              shape2 = f(
                  builtin.Concat(axis=0, comp_node=device),
                  build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                  shape2,
              )
              inp2 = f(builtin.Broadcast(), inp2, shape2)
              batch_shape = build_shape_head(shape1)
          if _dim2 > _dim1:
              # broadcast
              shape1 = f(
                  builtin.Concat(axis=0, comp_node=device),
                  build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                  shape1,
              )
              inp1 = f(builtin.Broadcast(), inp1, shape1)
              batch_shape = build_shape_head(shape2)
          if _dim1 == _dim2:
              batch_shape = build_shape_head(shape1)

          # compress inputs to 3d
          # (prod(batch_shape), rows, cols) for both operands.
          if maxdim > 3:
              inp1 = f(
                  builtin.Reshape(),
                  inp1,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      f(builtin.Reduce(mode="product", axis=0), batch_shape),
                      build_shape_tail(shape1),
                  ),
              )
              inp2 = f(
                  builtin.Reshape(),
                  inp2,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      f(builtin.Reduce(mode="product", axis=0), batch_shape),
                      build_shape_tail(shape2),
                  ),
              )
          op = builtin.BatchedMatrixMul(
              transposeA=transpose_a,
              transposeB=transpose_b,
              compute_mode=compute_mode,
              format=format,
              strategy=strategy.value,
          )
          result = f(op, inp1, inp2)
          # Restore the original batch dims in front of the per-matrix result.
          if maxdim > 3:
              result = f(
                  builtin.Reshape(),
                  result,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      batch_shape,
                      build_shape_tail(f(GetVarShape(), result)),
                  ),
              )
          # Drop the axes introduced for 1d operands; the result has maxdim dims here.
          if remove_row:
              result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
          if remove_col:
              result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
          # NOTE(review): second tuple is presumably a per-output flag required by
          # the subgraph decorator — confirm its meaning.
          return (result,), (True,)

      return extentedBatchedMatrixMulOp
  def matmul(
      inp1: Tensor,
      inp2: Tensor,
      transpose_a=False,
      transpose_b=False,
      compute_mode="default",
      format="default",
  ) -> Tensor:
      """
      Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.

      With different input dims, this function behaves differently:

      - Both 1-D tensors: simply forwards to ``dot``.
      - Both 2-D tensors: normal matrix multiplication.
      - If one input tensor is 1-D: matrix-vector multiplication.
      - If at least one tensor is 3-dimensional or higher, the other tensor should have dim >= 2;
        the batched matrix-matrix product is returned, and the tensor with the smaller dimension
        is broadcast. For example:

          - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
          - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
          - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`

      :param inp1: first matrix to be multiplied.
      :param inp2: second matrix to be multiplied.
      :param transpose_a: whether to transpose ``inp1`` before multiplying. Default: False
      :param transpose_b: whether to transpose ``inp2`` before multiplying. Default: False
      :param compute_mode: compute mode forwarded to the underlying MatrixMul /
          BatchedMatrixMul op; overridden with ``"float32"`` when AMP is enabled.
          Default: "default"
      :param format: format forwarded to the underlying MatrixMul /
          BatchedMatrixMul op. Default: "default"
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
          data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
          out = F.matmul(data1, data2)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[10. 13.]
           [28. 40.]]

      """
      if amp._enabled:
          compute_mode = "float32"
          inp1, inp2 = cast_tensors(inp1, inp2)
      else:
          # Bring both operands to a common dtype before building the op.
          dtype = dtype_promotion(inp1, inp2)
          if inp1.dtype != dtype:
              inp1 = inp1.astype(dtype)
          if inp2.dtype != dtype:
              inp2 = inp2.astype(dtype)

      dim1, dim2 = inp1.ndim, inp2.ndim
      assert dim1 > 0 and dim2 > 0
      maxdim = dim1 if dim1 > dim2 else dim2
      if dim1 == 1 and dim2 == 1:  # dispatch to Dot
          return dot(inp1, inp2)

      # The MatrixMul path folds inp1's leading dims into rows, which is only
      # correct when inp1 is not transposed; a transposed high-rank inp1 must go
      # through BatchedMatrixMul so the transpose applies to each matrix.
      if maxdim <= 2 or (dim2 <= 2 and not transpose_a):  # dispatch to MatrixMul
          get_op = _get_extentedMatrixMulOp
      else:  # dispatch to BatchedMatrixMul
          get_op = _get_extentedBatchedMatrixMulOp
      extended_op = get_op(
          inp1.device,
          inp1.dtype,
          dim1,
          dim2,
          transpose_a,
          transpose_b,
          compute_mode,
          format,
          strategy=_Hashable(get_execution_strategy()),
      )
      (result,) = apply(extended_op(), inp1, inp2)
      return result
  def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
      """
      Computes the dot product of two vectors ``inp1`` and ``inp2``.

      Inputs must be 1-dimensional or scalar; a scalar input is automatically broadcast.
      Refer to :func:`~.matmul` for more general usage.

      :param inp1: first vector.
      :param inp2: second vector.
      :return: output value (a scalar tensor).

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          data1 = tensor(np.arange(0, 6, dtype=np.float32))
          data2 = tensor(np.arange(0, 6, dtype=np.float32))
          out = F.dot(data1, data2)
          print(out.numpy())

      Outputs:

      .. testoutput::

          55.

      """
      assert (
          inp1.ndim <= 1 and inp2.ndim <= 1
      ), "Input tensors for dot must be 1-dimensional or scalar"
      (product,) = apply(builtin.Dot(), inp1, inp2)
      # Mark the output as a 0-d scalar.
      setscalar(product)
      return product
  def svd(
      inp: Tensor, full_matrices=False, compute_uv=True
  ) -> Tuple[Tensor, Tensor, Tensor]:
      """
      Computes the singular value decomposition of the input matrix.

      :param inp: input matrix, must have shape `[..., M, N]`.
      :param full_matrices: forwarded to the SVD op; presumably selects full-size
          rather than reduced ``U``/``V`` factors — confirm. Default: False
      :param compute_uv: forwarded to the SVD op; presumably controls whether ``U``
          and ``V`` are computed — confirm. Default: True
      :return: tuple of output tensors `(U, sigma, V)`.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
          _, y, _ = F.svd(x)
          print(y.numpy().round(decimals=3))

      Outputs:

      .. testoutput::

          [7.348 1.   ]

      """
      op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
      U, sigma, V = apply(op, inp)
      return U, sigma, V
  def _has_inf(inp: Tensor) -> Tensor:
      """
      Checks whether the input contains an infinite value.

      :param inp: a tensor to be checked (of any shape; it is flattened and cast
          to float32 before the check).
      :return: an int32 scalar tensor, 0 for False and 1 for True.
      """
      flat = inp.reshape(-1).astype("float32")
      (flag,) = apply(builtin.CheckHasInf(), flat)
      # Mark the single-element output as a 0-d scalar.
      flag._setscalar()
      return flag

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。