You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

vision.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Iterable, Optional, Tuple, Union
  10. from ..core._imperative_rt.core2 import apply
  11. from ..core.ops import builtin
  12. from ..core.tensor import megbrain_graph, utils
  13. from ..core.tensor.utils import astensor1d
  14. from ..jit.tracing import is_tracing
  15. from ..tensor import Tensor
  16. from .elemwise import floor
  17. from .math import argsort
  18. from .tensor import broadcast_to, concat, expand_dims, reshape
  19. def cvt_color(inp: Tensor, mode: str = ""):
  20. r"""
  21. Convert images from one format to another
  22. :param inp: input images.
  23. :param mode: format mode.
  24. :return: convert result.
  25. Examples:
  26. .. testcode::
  27. import numpy as np
  28. import megengine as mge
  29. import megengine.functional as F
  30. x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32))
  31. y = F.vision.cvt_color(x, mode="RGB2GRAY")
  32. print(y.numpy())
  33. Outputs:
  34. .. testoutput::
  35. [[[[0.86555195]]]]
  36. """
  37. mode = mode.upper()
  38. assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color"
  39. mode = getattr(builtin.CvtColor.Mode, mode)
  40. assert isinstance(mode, builtin.CvtColor.Mode)
  41. op = builtin.CvtColor(mode=mode)
  42. (out,) = apply(op, inp)
  43. return out
  44. def roi_pooling(
  45. inp: Tensor,
  46. rois: Tensor,
  47. output_shape: Union[int, tuple, list],
  48. mode: str = "max",
  49. scale: float = 1.0,
  50. ) -> Tensor:
  51. """
  52. Applies roi pooling on input feature.
  53. :param inp: tensor that represents the input feature, `(N, C, H, W)` images.
  54. :param rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
  55. :param output_shape: `(height, width)` of output rois feature.
  56. :param mode: "max" or "average", use max/average align just like max/average pooling. Default: "max"
  57. :param scale: scale the input boxes by this number. Default: 1.0
  58. :return: `(K, C, output_shape[0], output_shape[1])` feature of rois.
  59. Examples:
  60. .. testcode::
  61. import numpy as np
  62. from megengine import tensor
  63. import megengine.functional as F
  64. np.random.seed(42)
  65. inp = tensor(np.random.randn(1, 1, 128, 128))
  66. rois = tensor(np.random.random((4, 5)))
  67. y = F.vision.roi_pooling(inp, rois, (2, 2))
  68. print(y.numpy()[0].round(decimals=4))
  69. Outputs:
  70. .. testoutput::
  71. [[[-0.1383 -0.1383]
  72. [-0.5035 -0.5035]]]
  73. """
  74. assert mode.lower() in ["max", "average"], "only max/average mode is supported"
  75. if isinstance(output_shape, int):
  76. output_shape = (output_shape, output_shape)
  77. op = builtin.ROIPooling(mode=mode, scale=scale)
  78. inp, rois = utils.convert_inputs(inp, rois)
  79. result, _ = apply(
  80. op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device)
  81. )
  82. return result
  83. def correlation(
  84. data1: Tensor,
  85. data2: Tensor,
  86. kernel_size: int = 1,
  87. max_displacement: int = 1,
  88. stride1: int = 1,
  89. stride2: int = 1,
  90. pad_size: int = 0,
  91. is_multiply: bool = True,
  92. ) -> Tensor:
  93. """ Applies correlation to inputs.
  94. :param data1: Input data1 to the correlation. format must be nchw
  95. :param data2: Input data2 to the correlation. format must be nchw
  96. :param kernel_size: (int (non-negative), optional, default=1) – kernel size for Correlation must be an odd number
  97. :param max_displacement: (int (non-negative), optional, default=1) – Max displacement of Correlation
  98. :param stride1: (int (non-negative), optional, default=1) – stride1 quantize data1 globally
  99. :param stride2: (int (non-negative), optional, default=1) – stride2 quantize data2 within the neighborhood centered around data1
  100. :param pad_size: (int (non-negative), optional, default=0) – pad for Correlation
  101. :param is_multiply: (boolean, optional, default=True) – operation type is either multiplication or absolute difference
  102. """
  103. op = builtin.Correlation(
  104. format="NCHW",
  105. kernel_size=kernel_size,
  106. max_displacement=max_displacement,
  107. stride1=stride1,
  108. stride2=stride2,
  109. pad_size=pad_size,
  110. is_multiply=is_multiply,
  111. )
  112. result, *_ = apply(op, data1, data2)
  113. return result
  114. def roi_align(
  115. inp: Tensor,
  116. rois: Tensor,
  117. output_shape: Union[int, tuple, list],
  118. mode: str = "average",
  119. spatial_scale: float = 1.0,
  120. sample_points: Union[int, tuple, list] = 2,
  121. aligned: bool = True,
  122. ) -> Tensor:
  123. """
  124. Applies roi align on input feature.
  125. :param inp: tensor that represents the input feature, shape is `(N, C, H, W)`.
  126. :param rois: `(N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
  127. :param output_shape: `(height, width)` shape of output rois feature.
  128. :param mode: "max" or "average", use max/average align just like max/average pooling. Default: "average"
  129. :param spatial_scale: scale the input boxes by this number. Default: 1.0
  130. :param sample_points: number of inputs samples to take for each output sample.
  131. 0 to take samples densely. Default: 2
  132. :param aligned: wheather to align the input feature, with `aligned=True`,
  133. we first appropriately scale the ROI and then shift it by -0.5. Default: True
  134. :return: output tensor.
  135. Examples:
  136. .. testcode::
  137. import numpy as np
  138. from megengine import tensor
  139. import megengine.functional as F
  140. np.random.seed(42)
  141. inp = tensor(np.random.randn(1, 1, 128, 128))
  142. rois = tensor(np.random.random((4, 5)))
  143. y = F.vision.roi_align(inp, rois, (2, 2))
  144. print(y.numpy()[0].round(decimals=4))
  145. Outputs:
  146. .. testoutput::
  147. [[[0.175 0.175 ]
  148. [0.1359 0.1359]]]
  149. """
  150. mode = mode.lower()
  151. assert mode in ["max", "average"], "only max/average mode is supported"
  152. if isinstance(output_shape, int):
  153. output_shape = (output_shape, output_shape)
  154. pooled_height, pooled_width = output_shape
  155. if isinstance(sample_points, int):
  156. sample_points = (sample_points, sample_points)
  157. sample_height, sample_width = sample_points
  158. offset = 0.5 if aligned else 0.0
  159. op = builtin.ROIAlign(
  160. mode=mode,
  161. format="NCHW",
  162. spatial_scale=spatial_scale,
  163. offset=offset,
  164. pooled_height=pooled_height,
  165. pooled_width=pooled_width,
  166. sample_height=sample_height,
  167. sample_width=sample_width,
  168. )
  169. inp, rois = utils.convert_inputs(inp, rois)
  170. result, *_ = apply(op, inp, rois)
  171. return result
  172. def nms(
  173. boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None
  174. ) -> Tensor:
  175. r"""
  176. Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU).
  177. :param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
  178. :param iou_thresh: IoU threshold for overlapping.
  179. :param scores: tensor of shape `(N,)`, the score of boxes.
  180. :param max_output: the maximum number of boxes to keep; it is optional if this operator is not traced
  181. otherwise it required to be specified; if it is not specified, all boxes are kept.
  182. :return: indices of the elements that have been kept by NMS.
  183. Examples:
  184. .. testcode::
  185. import numpy as np
  186. from megengine import tensor
  187. import megengine.functional as F
  188. x = np.zeros((100,4))
  189. np.random.seed(42)
  190. x[:,:2] = np.random.rand(100,2)*20
  191. x[:,2:] = np.random.rand(100,2)*20 + 100
  192. scores = tensor(np.random.rand(100))
  193. inp = tensor(x)
  194. result = F.vision.nms(inp, scores, iou_thresh=0.7)
  195. print(result.numpy())
  196. Outputs:
  197. .. testoutput::
  198. [75 69]
  199. """
  200. assert (
  201. boxes.ndim == 2 and boxes.shape[1] == 4
  202. ), "the expected shape of boxes is (N, 4)"
  203. assert scores.ndim == 1, "the expected shape of scores is (N,)"
  204. assert (
  205. boxes.shape[0] == scores.shape[0]
  206. ), "number of boxes and scores are not matched"
  207. boxes = boxes.detach()
  208. scores = scores.detach()
  209. sorted_idx = argsort(scores, descending=True)
  210. boxes = boxes[sorted_idx]
  211. if is_tracing():
  212. assert (
  213. max_output is not None and max_output > 0
  214. ), "max_output should be specified under tracing"
  215. if max_output is None:
  216. max_output = boxes.shape[0]
  217. op = builtin.NMSKeep(iou_thresh, max_output)
  218. inp = utils.convert_inputs(boxes.reshape(1, -1, 4))
  219. indices, count = apply(op, *inp)
  220. indices = indices[0][: count[0]]
  221. keep_inds = sorted_idx[indices]
  222. return keep_inds
  223. def remap(
  224. inp: Tensor,
  225. map_xy: Tensor,
  226. border_mode: str = "replicate",
  227. scalar: float = 0.0,
  228. interp_mode: str = "linear",
  229. ) -> Tensor:
  230. r"""
  231. Applies remap transformation to batched 2D images.
  232. The input images are transformed to the output images by the tensor map_xy.
  233. The output's H and W are same as map_xy's H and W.
  234. :param inp: input image
  235. :param map_xy: (batch, oh, ow, 2) transformation matrix
  236. :param border_mode: pixel extrapolation method.
  237. Default: "replicate". Currently also support "constant", "reflect",
  238. "reflect_101", "wrap".
  239. :param scalar: value used in case of a constant border. Default: 0
  240. :param interp_mode: interpolation methods.
  241. Default: "linear". Currently only support "linear" mode.
  242. :return: output tensor.
  243. Examples:
  244. .. testcode::
  245. import numpy as np
  246. from megengine import tensor
  247. import megengine.functional as F
  248. inp_shape = (1, 1, 4, 4)
  249. inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  250. map_xy_shape = (1, 2, 2, 2)
  251. map_xy = tensor(np.array([[[1., 0.],[0., 1.]],
  252. [[0., 1.],[0., 1.]]],
  253. dtype=np.float32).reshape(map_xy_shape))
  254. out = F.vision.remap(inp, map_xy)
  255. print(out.numpy())
  256. Outputs:
  257. .. testoutput::
  258. [[[[1. 4.]
  259. [4. 4.]]]]
  260. """
  261. op = builtin.Remap(
  262. imode=interp_mode, border_type=border_mode, format="NCHW", scalar=scalar
  263. )
  264. assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
  265. (result,) = apply(op, inp, map_xy)
  266. return result
  267. def warp_affine(
  268. inp: Tensor,
  269. mat: Tensor,
  270. out_shape: Union[Tuple[int, int], int, Tensor],
  271. border_mode: str = "replicate",
  272. border_val: float = 0.0,
  273. format: str = "NHWC",
  274. interp_mode: str = "linear",
  275. ) -> Tensor:
  276. """
  277. Batched affine transform on 2D images.
  278. :param inp: input image.
  279. :param mat: `(batch, 2, 3)` transformation matrix.
  280. :param out_shape: output tensor shape.
  281. :param border_mode: pixel extrapolation method.
  282. Default: "wrap". Currently "constant", "reflect",
  283. "reflect_101", "isolated", "wrap", "replicate", "transparent" are supported.
  284. :param border_val: value used in case of a constant border. Default: 0
  285. :param format: "NHWC" as default based on historical concerns,
  286. "NCHW" is also supported. Default: "NHWC".
  287. :param interp_mode: interpolation methods. Could be "linear", "nearest", "cubic", "area".
  288. Default: "linear".
  289. :return: output tensor.
  290. .. note::
  291. Here all available options for params are listed,
  292. however it does not mean that you can use all the combinations.
  293. On different platforms, different combinations are supported.
  294. """
  295. op = builtin.WarpAffine(
  296. border_mode=border_mode,
  297. border_val=border_val,
  298. format=format,
  299. imode=interp_mode,
  300. )
  301. out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device)
  302. (result,) = apply(op, inp, mat, out_shape)
  303. return result
  304. def warp_perspective(
  305. inp: Tensor,
  306. mat: Tensor,
  307. out_shape: Union[Tuple[int, int], int, Tensor],
  308. mat_idx: Optional[Union[Iterable[int], Tensor]] = None,
  309. border_mode: str = "replicate",
  310. border_val: float = 0.0,
  311. format: str = "NCHW",
  312. interp_mode: str = "linear",
  313. ) -> Tensor:
  314. r"""
  315. Applies perspective transformation to batched 2D images.
  316. The input images are transformed to the output images by the transformation matrix:
  317. .. math::
  318. \text{output}(n, c, h, w) = \text{input} \left( n, c,
  319. \frac{M_{00}h + M_{01}w + M_{02}}{M_{20}h + M_{21}w + M_{22}},
  320. \frac{M_{10}h + M_{11}w + M_{12}}{M_{20}h + M_{21}w + M_{22}}
  321. \right)
  322. Optionally, we can set `mat_idx` to assign different transformations to the same image,
  323. otherwise the input images and transformations should be one-to-one correnspondence.
  324. :param inp: input image.
  325. :param mat: `(batch, 3, 3)` transformation matrix.
  326. :param out_shape: `(h, w)` size of the output image.
  327. :param mat_idx: `(batch, )` image batch idx assigned to each matrix. Default: None
  328. :param border_mode: pixel extrapolation method.
  329. Default: "replicate". Currently also support "constant", "reflect",
  330. "reflect_101", "wrap".
  331. :param border_val: value used in case of a constant border. Default: 0
  332. :param format: "NHWC" is also supported. Default: "NCHW".
  333. :param interp_mode: interpolation methods.
  334. Default: "linear". Currently only support "linear" mode.
  335. :return: output tensor.
  336. .. note::
  337. The transformation matrix is the inverse of that used by `cv2.warpPerspective`.
  338. Examples:
  339. .. testcode::
  340. import numpy as np
  341. from megengine import tensor
  342. import megengine.functional as F
  343. inp_shape = (1, 1, 4, 4)
  344. x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  345. M_shape = (1, 3, 3)
  346. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  347. M = tensor(np.array([[1., 0., 1.],
  348. [0., 1., 1.],
  349. [0., 0., 1.]], dtype=np.float32).reshape(M_shape))
  350. out = F.vision.warp_perspective(x, M, (2, 2))
  351. print(out.numpy())
  352. Outputs:
  353. .. testoutput::
  354. [[[[ 5. 6.]
  355. [ 9. 10.]]]]
  356. """
  357. op = builtin.WarpPerspective(
  358. imode=interp_mode, bmode=border_mode, format=format, border_val=border_val
  359. )
  360. inp, mat = utils.convert_inputs(inp, mat)
  361. out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
  362. if mat_idx is not None:
  363. mat_idx = astensor1d(mat_idx, inp, dtype="int32", device=inp.device)
  364. (result,) = apply(op, inp, mat, mat_idx, out_shape)
  365. return result
  366. (result,) = apply(op, inp, mat, out_shape)
  367. return result
  368. def interpolate(
  369. inp: Tensor,
  370. size: Optional[Union[int, Tuple[int, int]]] = None,
  371. scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
  372. mode: str = "bilinear",
  373. align_corners: Optional[bool] = None,
  374. ) -> Tensor:
  375. r"""
  376. Down/up samples the input tensor to either the given size or with the given scale_factor. ``size`` can not coexist with ``scale_factor``.
  377. :param inp: input tensor.
  378. :param size: size of the output tensor. Default: None
  379. :param scale_factor: scaling factor of the output tensor. Default: None
  380. :param mode: interpolation methods, acceptable values are:
  381. "bilinear", "linear". Default: "bilinear"
  382. :param align_corners: This only has an effect when `mode`
  383. is "bilinear" or "linear". Geometrically, we consider the pixels of the input
  384. and output as squares rather than points. If set to ``True``, the input
  385. and output tensors are aligned by the center points of their corner
  386. pixels, preserving the values at the corner pixels. If set to ``False``,
  387. the input and output tensors are aligned by the corner points of their
  388. corner pixels, and the interpolation uses edge value padding for
  389. out-of-boundary values, making this operation *independent* of input size
  390. :return: output tensor.
  391. Examples:
  392. .. testcode::
  393. import numpy as np
  394. from megengine import tensor
  395. import megengine.functional as F
  396. x = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  397. out = F.vision.interpolate(x, [4, 4], align_corners=False)
  398. print(out.numpy())
  399. out2 = F.vision.interpolate(x, scale_factor=2.)
  400. np.testing.assert_allclose(out.numpy(), out2.numpy())
  401. Outputs:
  402. .. testoutput::
  403. [[[[1. 1.25 1.75 2. ]
  404. [1.5 1.75 2.25 2.5 ]
  405. [2.5 2.75 3.25 3.5 ]
  406. [3. 3.25 3.75 4. ]]]]
  407. """
  408. mode = mode.lower()
  409. if mode not in ["bilinear", "linear"]:
  410. raise ValueError("interpolate only support linear or bilinear mode")
  411. if mode not in ["bilinear", "linear"]:
  412. if align_corners is not None:
  413. raise ValueError(
  414. "align_corners option can only be set in the bilinear/linear interpolating mode"
  415. )
  416. else:
  417. if align_corners is None:
  418. align_corners = False
  419. if (
  420. size is not None
  421. and scale_factor is None
  422. and not align_corners
  423. and mode == "bilinear"
  424. and inp.ndim in [4, 5]
  425. ):
  426. # fastpath for interpolate
  427. op = builtin.Resize(imode="linear", format="NCHW")
  428. shape = astensor1d(size, inp, dtype="int32", device=inp.device)
  429. (result,) = apply(op, inp, shape)
  430. return result
  431. if mode == "linear":
  432. inp = expand_dims(inp, 3)
  433. if inp.ndim != 4:
  434. raise ValueError("shape of input tensor must correspond to the operartion mode")
  435. if size is None:
  436. if scale_factor is None:
  437. raise ValueError("scale_factor must not be None when size is None")
  438. if isinstance(scale_factor, (float, int)):
  439. scale_factor = float(scale_factor)
  440. if mode == "linear":
  441. scale_factor = (scale_factor, float(1))
  442. else:
  443. scale_factor = (scale_factor, scale_factor)
  444. else:
  445. if mode == "linear":
  446. raise ValueError(
  447. "under linear mode, scale_factor can only be single value"
  448. )
  449. assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )"
  450. assert isinstance(scale_factor[0], float) and isinstance(
  451. scale_factor[1], float
  452. ), "scale_factor must be float type"
  453. dsize = tuple(
  454. floor(
  455. Tensor(
  456. inp.shape[i + 2] * scale_factor[i],
  457. dtype="float32",
  458. device=inp.device,
  459. )
  460. )
  461. for i in range(2)
  462. )
  463. dsize = concat([dsize[0], dsize[1]], axis=0)
  464. else:
  465. if scale_factor is not None:
  466. raise ValueError("scale_factor must be None when size is provided")
  467. if isinstance(size, int):
  468. size = (size, 1)
  469. else:
  470. if mode == "linear":
  471. raise ValueError("under linear mode, size can only be single value")
  472. dsize = size
  473. oh, ow = dsize[0], dsize[1]
  474. ih, iw = inp.shape[2], inp.shape[3]
  475. if align_corners:
  476. hscale = (ih - 1.0) / (oh - 1.0)
  477. wscale = 1.0 * iw / ow
  478. if mode != "linear":
  479. wscale = (iw - 1.0) / (ow - 1.0)
  480. row0 = concat(
  481. [wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0
  482. ).reshape(1, 3)
  483. row1 = concat(
  484. [
  485. Tensor(0, dtype="float32", device=inp.device),
  486. hscale,
  487. Tensor(0, dtype="float32", device=inp.device),
  488. ],
  489. axis=0,
  490. ).reshape(1, 3)
  491. weight = concat(
  492. [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
  493. axis=0,
  494. ).reshape(1, 3, 3)
  495. weight = broadcast_to(weight, (inp.shape[0], 3, 3))
  496. else:
  497. hscale = 1.0 * ih / oh
  498. wscale = 1.0 * iw / ow
  499. row0 = concat(
  500. [wscale, Tensor(0, dtype="float32", device=inp.device), 0.5 * wscale - 0.5],
  501. axis=0,
  502. ).reshape(1, 3)
  503. row1 = concat(
  504. [Tensor(0, dtype="float32", device=inp.device), hscale, 0.5 * hscale - 0.5],
  505. axis=0,
  506. ).reshape(1, 3)
  507. weight = concat(
  508. [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
  509. axis=0,
  510. ).reshape(1, 3, 3)
  511. weight = broadcast_to(weight, (inp.shape[0], 3, 3))
  512. weight = weight.astype("float32")
  513. ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
  514. if mode == "linear":
  515. ret = reshape(ret, ret.shape[0:3])
  516. return ret
  517. def nvof(src: Tensor, precision: int = 1) -> Tensor:
  518. r"""
  519. Implements NVIDIA Optical Flow SDK.
  520. :src shape: input tensor with shape (n, t, h, w, c4).
  521. :src dtype: uint8.
  522. :param precision: 0:NV_OF_PERF_LEVEL_SLOW 1:NV_OF_PERF_LEVEL_MEDIUM 2:NV_OF_PERF_LEVEL_FAST.
  523. :output shape: (n, t-1, h//4, w//4, c2).
  524. :output dtype: int16.
  525. .. code-block:: python
  526. import numpy as np
  527. from megengine import tensor
  528. import megengine.functional as F
  529. x = np.random.random_integers(0, 255, (1,2,224,244,4)).astype("uint8")
  530. src = tensor(x)
  531. result = F.nn.nvof(src, precision=1)
  532. print(result.numpy())
  533. """
  534. assert src.ndim == 5 and src.shape[4] == 4
  535. src = src.detach()
  536. op = builtin.NvOf(precision=precision)
  537. return apply(op, src)[0]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台