You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

vision.py 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Iterable, Optional, Tuple, Union
  10. import numpy as np
  11. from ..core import _config
  12. from ..core._imperative_rt.core2 import apply
  13. from ..core.ops import builtin
  14. from ..core.tensor import megbrain_graph, utils
  15. from ..core.tensor.utils import astensor1d
  16. from ..tensor import Tensor
  17. from .elemwise import floor
  18. from .math import argsort
  19. from .tensor import broadcast_to, concat, expand_dims, reshape, transpose
# Public names exported by this module via ``from ... import *``.
__all__ = [
    "correlation",
    "cvt_color",
    "roi_pooling",
    "roi_align",
    "nms",
    "remap",
    "warp_affine",
    "warp_perspective",
    "interpolate",
    "nvof",
]
  32. def cvt_color(inp: Tensor, mode: str = ""):
  33. r"""Convert images from one format to another
  34. Args:
  35. inp: input images.
  36. mode: format mode.
  37. Returns:
  38. convert result.
  39. Examples:
  40. .. testcode::
  41. import numpy as np
  42. import megengine as mge
  43. import megengine.functional as F
  44. x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32))
  45. y = F.vision.cvt_color(x, mode="RGB2GRAY")
  46. print(y.numpy())
  47. Outputs:
  48. .. testoutput::
  49. [[[[0.86555195]]]]
  50. """
  51. mode = mode.upper()
  52. assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color"
  53. mode = getattr(builtin.CvtColor.Mode, mode)
  54. assert isinstance(mode, builtin.CvtColor.Mode)
  55. op = builtin.CvtColor(mode=mode)
  56. (out,) = apply(op, inp)
  57. return out
  58. def roi_pooling(
  59. inp: Tensor,
  60. rois: Tensor,
  61. output_shape: Union[int, tuple, list],
  62. mode: str = "max",
  63. scale: float = 1.0,
  64. ) -> Tensor:
  65. r"""Applies roi pooling on input feature.
  66. Args:
  67. inp: tensor that represents the input feature, `(N, C, H, W)` images.
  68. rois: K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
  69. output_shape: height, width)` of output rois feature.
  70. mode: max" or "average", use max/average align just like max/average pooling. Default: "max"
  71. scale: scale the input boxes by this number. Default: 1.0
  72. Returns:
  73. ``K, C, output_shape[0], output_shape[1])`` feature of rois.
  74. Examples:
  75. .. testcode::
  76. import numpy as np
  77. from megengine import tensor
  78. import megengine.functional as F
  79. np.random.seed(42)
  80. inp = tensor(np.random.randn(1, 1, 128, 128))
  81. rois = tensor(np.random.random((4, 5)))
  82. y = F.vision.roi_pooling(inp, rois, (2, 2))
  83. print(y.numpy()[0].round(decimals=4))
  84. Outputs:
  85. .. testoutput::
  86. [[[-0.1383 -0.1383]
  87. [-0.5035 -0.5035]]]
  88. """
  89. assert mode.lower() in ["max", "average"], "only max/average mode is supported"
  90. if isinstance(output_shape, int):
  91. output_shape = (output_shape, output_shape)
  92. op = builtin.ROIPooling(mode=mode, scale=scale)
  93. result, _ = apply(
  94. op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device)
  95. )
  96. return result
  97. def correlation(
  98. data1: Tensor,
  99. data2: Tensor,
  100. kernel_size: int = 1,
  101. max_displacement: int = 1,
  102. stride1: int = 1,
  103. stride2: int = 1,
  104. pad_size: int = 0,
  105. is_multiply: bool = True,
  106. ) -> Tensor:
  107. r"""Applies correlation to inputs.
  108. Args:
  109. data1: Input data1 to the correlation. format must be nchw
  110. data2: Input data2 to the correlation. format must be nchw
  111. kernel_size: int (non-negative), optional, default=1) – kernel size for Correlation must be an odd number
  112. max_displacement: int (non-negative), optional, default=1) – Max displacement of Correlation
  113. stride1: int (non-negative), optional, default=1) – stride1 quantize data1 globally
  114. stride2: int (non-negative), optional, default=1) – stride2 quantize data2 within the neighborhood centered around data1
  115. pad_size: int (non-negative), optional, default=0) – pad for Correlation
  116. is_multiply: boolean, optional, default=True) – operation type is either multiplication or absolute difference
  117. """
  118. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  119. assert conv_format == "NCHW", "Currently correlation only support NCHW mode"
  120. op = builtin.Correlation(
  121. format=conv_format,
  122. kernel_size=kernel_size,
  123. max_displacement=max_displacement,
  124. stride1=stride1,
  125. stride2=stride2,
  126. pad_size=pad_size,
  127. is_multiply=is_multiply,
  128. )
  129. result, *_ = apply(op, data1, data2)
  130. return result
  131. def roi_align(
  132. inp: Tensor,
  133. rois: Tensor,
  134. output_shape: Union[int, tuple, list],
  135. mode: str = "average",
  136. spatial_scale: float = 1.0,
  137. sample_points: Union[int, tuple, list] = 2,
  138. aligned: bool = True,
  139. ) -> Tensor:
  140. r"""Applies roi align on input feature.
  141. Args:
  142. inp: tensor that represents the input feature, shape is `(N, C, H, W)`.
  143. rois: N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
  144. output_shape: height, width)` shape of output rois feature.
  145. mode: max" or "average", use max/average align just like max/average pooling. Default: "average"
  146. spatial_scale: scale the input boxes by this number. Default: 1.0
  147. sample_points: number of inputs samples to take for each output sample.
  148. 0 to take samples densely. Default: 2
  149. aligned: wheather to align the input feature, with `aligned=True`,
  150. we first appropriately scale the ROI and then shift it by -0.5. Default: True
  151. Returns:
  152. output tensor.
  153. Examples:
  154. .. testcode::
  155. import numpy as np
  156. from megengine import tensor
  157. import megengine.functional as F
  158. np.random.seed(42)
  159. inp = tensor(np.random.randn(1, 1, 128, 128))
  160. rois = tensor(np.random.random((4, 5)))
  161. y = F.vision.roi_align(inp, rois, (2, 2))
  162. print(y.numpy()[0].round(decimals=4))
  163. Outputs:
  164. .. testoutput::
  165. [[[0.175 0.175 ]
  166. [0.1359 0.1359]]]
  167. """
  168. if inp.dtype != np.float32:
  169. inp = inp.astype(np.float32)
  170. mode = mode.lower()
  171. assert mode in ["max", "average"], "only max/average mode is supported"
  172. if isinstance(output_shape, int):
  173. output_shape = (output_shape, output_shape)
  174. pooled_height, pooled_width = output_shape
  175. if isinstance(sample_points, int):
  176. sample_points = (sample_points, sample_points)
  177. sample_height, sample_width = sample_points
  178. offset = 0.5 if aligned else 0.0
  179. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  180. assert conv_format == "NCHW", "Currently roi_align only support NCHW mode"
  181. op = builtin.ROIAlign(
  182. mode=mode,
  183. format=conv_format,
  184. spatial_scale=spatial_scale,
  185. offset=offset,
  186. pooled_height=pooled_height,
  187. pooled_width=pooled_width,
  188. sample_height=sample_height,
  189. sample_width=sample_width,
  190. )
  191. result, *_ = apply(op, inp, rois)
  192. return result
  193. def nms(
  194. boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None
  195. ) -> Tensor:
  196. r"""Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU).
  197. Args:
  198. boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
  199. iou_thresh: IoU threshold for overlapping.
  200. scores: tensor of shape `(N,)`, the score of boxes.
  201. max_output: the maximum number of boxes to keep; it is optional if this operator is not traced
  202. otherwise it required to be specified; if it is not specified, all boxes are kept.
  203. Returns:
  204. indices of the elements that have been kept by NMS, sorted by scores.
  205. Note:
  206. max_output should be specified and should have valid positive value under tracing.
  207. Examples:
  208. .. testcode::
  209. import numpy as np
  210. from megengine import tensor
  211. import megengine.functional as F
  212. x = np.zeros((100,4))
  213. np.random.seed(42)
  214. x[:,:2] = np.random.rand(100,2)*20
  215. x[:,2:] = np.random.rand(100,2)*20 + 100
  216. scores = tensor(np.random.rand(100))
  217. inp = tensor(x)
  218. result = F.vision.nms(inp, scores, iou_thresh=0.7)
  219. print(result.numpy())
  220. Outputs:
  221. .. testoutput::
  222. [75 69]
  223. """
  224. assert (
  225. boxes.ndim == 2 and boxes.shape[1] == 4
  226. ), "the expected shape of boxes is (N, 4)"
  227. assert scores.ndim == 1, "the expected shape of scores is (N,)"
  228. assert (
  229. boxes.shape[0] == scores.shape[0]
  230. ), "number of boxes and scores are not matched"
  231. boxes = boxes.detach()
  232. scores = scores.detach()
  233. sorted_idx = argsort(scores, descending=True)
  234. boxes = boxes[sorted_idx]
  235. if max_output is None:
  236. max_output = boxes.shape[0]
  237. op = builtin.NMSKeep(iou_thresh, max_output)
  238. inp = (boxes.reshape(1, -1, 4),)
  239. indices, count = apply(op, *inp)
  240. indices = indices[0][: count[0]]
  241. keep_inds = sorted_idx[indices]
  242. return keep_inds
  243. def remap(
  244. inp: Tensor,
  245. map_xy: Tensor,
  246. border_mode: str = "replicate",
  247. scalar: float = 0.0,
  248. interp_mode: str = "linear",
  249. ) -> Tensor:
  250. r"""Applies remap transformation to batched 2D images.
  251. The input images are transformed to the output images by the tensor map_xy.
  252. The output's H and W are same as map_xy's H and W.
  253. Args:
  254. inp: input image
  255. map_xy: batch, oh, ow, 2) transformation matrix
  256. border_mode: pixel extrapolation method.
  257. Default: "replicate". Currently also support "constant", "reflect",
  258. "reflect_101", "wrap".
  259. scalar: value used in case of a constant border. Default: 0
  260. interp_mode: interpolation methods.
  261. Default: "linear". Currently only support "linear" mode.
  262. Returns:
  263. output tensor.
  264. Examples:
  265. .. testcode::
  266. import numpy as np
  267. from megengine import tensor
  268. import megengine.functional as F
  269. inp_shape = (1, 1, 4, 4)
  270. inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  271. map_xy_shape = (1, 2, 2, 2)
  272. map_xy = tensor(np.array([[[1., 0.],[0., 1.]],
  273. [[0., 1.],[0., 1.]]],
  274. dtype=np.float32).reshape(map_xy_shape))
  275. out = F.vision.remap(inp, map_xy)
  276. print(out.numpy())
  277. Outputs:
  278. .. testoutput::
  279. [[[[1. 4.]
  280. [4. 4.]]]]
  281. """
  282. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  283. op = builtin.Remap(
  284. imode=interp_mode, border_type=border_mode, format=conv_format, scalar=scalar
  285. )
  286. assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
  287. (result,) = apply(op, inp, map_xy)
  288. return result
  289. def warp_affine(
  290. inp: Tensor,
  291. mat: Tensor,
  292. out_shape: Union[Tuple[int, int], int, Tensor],
  293. border_mode: str = "replicate",
  294. border_val: float = 0.0,
  295. format: str = "NHWC",
  296. interp_mode: str = "linear",
  297. ) -> Tensor:
  298. r"""Batched affine transform on 2D images.
  299. Args:
  300. inp: input image.
  301. mat: batch, 2, 3)` transformation matrix.
  302. out_shape: output tensor shape.
  303. border_mode: pixel extrapolation method.
  304. Default: "wrap". Currently "constant", "reflect",
  305. "reflect_101", "isolated", "wrap", "replicate", "transparent" are supported.
  306. border_val: value used in case of a constant border. Default: 0
  307. format: NHWC" as default based on historical concerns,
  308. "NCHW" is also supported. Default: "NHWC".
  309. interp_mode: interpolation methods. Could be "linear", "nearest", "cubic", "area".
  310. Default: "linear".
  311. Returns:
  312. output tensor.
  313. Note:
  314. Here all available options for params are listed,
  315. however it does not mean that you can use all the combinations.
  316. On different platforms, different combinations are supported.
  317. """
  318. conv_format = _config._get_actual_op_param(format, _config.__conv_format)
  319. op = builtin.WarpAffine(
  320. border_mode=border_mode,
  321. border_val=border_val,
  322. format=conv_format,
  323. imode=interp_mode,
  324. )
  325. out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device)
  326. (result,) = apply(op, inp, mat, out_shape)
  327. return result
  328. def warp_perspective(
  329. inp: Tensor,
  330. mat: Tensor,
  331. out_shape: Union[Tuple[int, int], int, Tensor],
  332. mat_idx: Optional[Union[Iterable[int], Tensor]] = None,
  333. border_mode: str = "replicate",
  334. border_val: float = 0.0,
  335. format: str = "NCHW",
  336. interp_mode: str = "linear",
  337. ) -> Tensor:
  338. r"""Applies perspective transformation to batched 2D images.
  339. The input images are transformed to the output images by the transformation matrix:
  340. .. math::
  341. \text{output}(n, c, h, w) = \text{input} \left( n, c,
  342. \frac{M_{00}w + M_{01}h + M_{02}}{M_{20}w + M_{21}h + M_{22}},
  343. \frac{M_{10}w + M_{11}h + M_{12}}{M_{20}w + M_{21}h + M_{22}}
  344. \right)
  345. Optionally, we can set `mat_idx` to assign different transformations to the same image,
  346. otherwise the input images and transformations should be one-to-one correnspondence.
  347. Args:
  348. inp: input image.
  349. mat: batch, 3, 3)` transformation matrix.
  350. out_shape: h, w)` size of the output image.
  351. mat_idx: batch, )` image batch idx assigned to each matrix. Default: None
  352. border_mode: pixel extrapolation method.
  353. Default: "replicate". Currently also support "constant", "reflect",
  354. "reflect_101", "wrap".
  355. border_val: value used in case of a constant border. Default: 0
  356. format: NHWC" is also supported. Default: "NCHW".
  357. interp_mode: interpolation methods.
  358. Default: "linear". Currently only support "linear" mode.
  359. Returns:
  360. output tensor.
  361. Note:
  362. The transformation matrix is the inverse of that used by `cv2.warpPerspective`.
  363. Examples:
  364. .. testcode::
  365. import numpy as np
  366. from megengine import tensor
  367. import megengine.functional as F
  368. inp_shape = (1, 1, 4, 4)
  369. x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  370. M_shape = (1, 3, 3)
  371. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  372. M = tensor(np.array([[1., 0., 1.],
  373. [0., 1., 1.],
  374. [0., 0., 1.]], dtype=np.float32).reshape(M_shape))
  375. out = F.vision.warp_perspective(x, M, (2, 2))
  376. print(out.numpy())
  377. Outputs:
  378. .. testoutput::
  379. [[[[ 5. 6.]
  380. [ 9. 10.]]]]
  381. """
  382. if inp.dtype == np.float32:
  383. mat = mat.astype("float32")
  384. if inp.dtype == np.float16:
  385. inp = inp.astype("float32")
  386. conv_format = _config._get_actual_op_param(format, _config.__conv_format)
  387. op = builtin.WarpPerspective(
  388. imode=interp_mode, bmode=border_mode, format=conv_format, border_val=border_val
  389. )
  390. out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
  391. if mat_idx is not None:
  392. mat_idx = astensor1d(mat_idx, inp, dtype="int32", device=inp.device)
  393. (result,) = apply(op, inp, mat, mat_idx, out_shape)
  394. return result
  395. (result,) = apply(op, inp, mat, out_shape)
  396. return result
def interpolate(
    inp: Tensor,
    size: Optional[Union[int, Tuple[int, int]]] = None,
    scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
    mode: str = "bilinear",
    align_corners: Optional[bool] = None,
) -> Tensor:
    r"""Down/up samples the input tensor to either the given size or with the
    given scale_factor. ``size`` can not coexist with ``scale_factor``.

    Args:
        inp: input tensor. After the extra axis added in "linear" mode it
            must be 4-dimensional.
        size: size of the output tensor. Under "linear" mode it must be a
            single int. Default: None
        scale_factor: scaling factor of the output tensor. Under "linear"
            mode it must be a single value; a pair must consist of two
            floats. Default: None
        mode: interpolation methods, acceptable values are:
            "bilinear", "linear", "bicubic" and "nearest" (case-insensitive).
            Default: "bilinear"
        align_corners: This only has an effect when ``mode``
            is "bilinear" or "linear". Geometrically, we consider the pixels
            of the input and output as squares rather than points. If set to
            ``True``, the input and output tensors are aligned by the center
            points of their corner pixels, preserving the values at the
            corner pixels. If set to ``False``, the input and output tensors
            are aligned by the corner points of their corner pixels, and the
            interpolation uses edge value padding for out-of-boundary values,
            making this operation *independent* of input size. ``None`` is
            treated as ``False`` for "bilinear"/"linear".

    Returns:
        output tensor.

    Raises:
        ValueError: if ``mode`` is unsupported; if ``align_corners`` is set
            for a mode other than "bilinear"/"linear"; if the input is not
            4-dimensional (after the "linear" expansion); if both or neither
            of ``size`` and ``scale_factor`` are given; or if a non-scalar
            ``size``/``scale_factor`` is used under "linear" mode.
        AssertionError: if ``scale_factor`` is a sequence that is not two
            floats, or if the effective convolution format is not "NCHW" on
            the non-``align_corners`` path.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
            out = F.vision.interpolate(x, [4, 4], align_corners=False)
            print(out.numpy())
            out2 = F.vision.interpolate(x, scale_factor=2.)
            np.testing.assert_allclose(out.numpy(), out2.numpy())

        Outputs:

        .. testoutput::

            [[[[1.   1.25 1.75 2.  ]
               [1.5  1.75 2.25 2.5 ]
               [2.5  2.75 3.25 3.5 ]
               [3.   3.25 3.75 4.  ]]]]
    """
    mode = mode.lower()
    if mode not in ["bilinear", "linear", "bicubic", "nearest"]:
        raise ValueError("unsupported interpolate mode: {}".format(mode))
    if mode not in ["bilinear", "linear"]:
        # align_corners is meaningless for bicubic/nearest: reject any
        # explicit value rather than silently ignoring it.
        if align_corners is not None:
            raise ValueError(
                "align_corners option can only be set in the bilinear/linear interpolating mode"
            )
    else:
        if align_corners is None:
            align_corners = False

    # "linear" mode is handled as a 2-D resize by appending an axis at
    # position 3; that axis is dropped again before returning.
    if mode == "linear":
        inp = expand_dims(inp, 3)

    if inp.ndim != 4:
        raise ValueError("shape of input tensor must correspond to the operartion mode")

    def get_dsize(scale_factor):
        # Turns a scale factor into a 2-element destination-size tensor by
        # scaling the input's axes 2 and 3 and flooring the result.
        if isinstance(scale_factor, (float, int)):
            scale_factor = float(scale_factor)
            if mode == "linear":
                # The appended axis must keep its extent, hence factor 1.
                scale_factor = (scale_factor, float(1))
            else:
                scale_factor = (scale_factor, scale_factor)
        else:
            if mode == "linear":
                raise ValueError(
                    "under linear mode, scale_factor can only be single value"
                )

        assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )"
        assert isinstance(scale_factor[0], float) and isinstance(
            scale_factor[1], float
        ), "scale_factor must be float type"
        dsize = tuple(
            floor(
                Tensor(
                    inp.shape[i + 2] * scale_factor[i],
                    dtype="float32",
                    device=inp.device,
                )
            )
            for i in range(2)
        )
        dsize = concat([dsize[0], dsize[1]], axis=0)
        return dsize

    if size is None:
        if scale_factor is None:
            raise ValueError("scale_factor must not be None when size is None")
        dsize = get_dsize(scale_factor)
    else:
        if scale_factor is not None:
            raise ValueError("scale_factor must be None when size is provided")

        # NOTE(review): an int size becomes (size, 1) in every mode, not only
        # "linear"; for the 2-D modes this yields an output of width 1 --
        # confirm this is intended.
        if isinstance(size, int):
            size = (size, 1)
        else:
            if mode == "linear":
                raise ValueError("under linear mode, size can only be single value")
        dsize = size

    if not align_corners:
        # fastpath for interpolate
        mode_map = {
            "linear": "linear",
            "bilinear": "linear",
            "nearest": "nearest",
            "bicubic": "cubic",
        }
        # float16 inputs are promoted to float32 before the resize op; the
        # result is not cast back here.
        if inp.dtype == np.float16:
            inp = inp.astype("float32")
        conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
        assert conv_format == "NCHW", "Currently resize only support NCHW mode"
        op = builtin.Resize(imode=mode_map[mode], format=conv_format)
        shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
        (ret,) = apply(op, inp, shape)
    else:
        # align_corners path: express the resize as a pure scaling 3x3
        # matrix and delegate to warp_perspective.
        assert mode in [
            "linear",
            "bilinear",
        ], "align_corners only support linear or bilinear mode"
        oh, ow = dsize[0], dsize[1]
        ih, iw = inp.shape[2], inp.shape[3]
        # NOTE(review): the divisors below are zero when the requested
        # output extent is 1 -- confirm how callers are expected to handle
        # that case.
        hscale = (ih - 1.0) / (oh - 1.0)
        wscale = 1.0 * iw / ow
        if mode != "linear":
            wscale = (iw - 1.0) / (ow - 1.0)
        # Rows of the matrix: [wscale, 0, 0], [0, hscale, 0], [0, 0, 1].
        row0 = concat(
            [wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0
        ).reshape(1, 3)
        row1 = concat(
            [
                Tensor(0, dtype="float32", device=inp.device),
                hscale,
                Tensor(0, dtype="float32", device=inp.device),
            ],
            axis=0,
        ).reshape(1, 3)
        weight = concat(
            [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
            axis=0,
        ).reshape(1, 3, 3)
        # Same matrix for every image in the batch.
        weight = broadcast_to(weight, (inp.shape[0], 3, 3))

        ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
    # Drop the axis that was appended for "linear" mode.
    if mode == "linear":
        ret = reshape(ret, ret.shape[0:3])
    return ret
  540. def nvof(src: Tensor, precision: int = 1) -> Tensor:
  541. r"""Implements NVIDIA Optical Flow SDK.
  542. Args:
  543. src: input tensor with shape (n, t, h, w, c4) and unit8 dtype.
  544. precision: 0:NV_OF_PERF_LEVEL_SLOW 1:NV_OF_PERF_LEVEL_MEDIUM 2:NV_OF_PERF_LEVEL_FAST.
  545. Returns:
  546. output tensor with shape: ``(n, t-1, (h+out_grid_size-1)//out_grid_size, (w+out_grid_size-1)//out_grid_size, c2)``.
  547. By default, out_grid_size = 4. dtype: int16.
  548. .. code-block:: python
  549. import numpy as np
  550. from megengine import tensor
  551. import megengine.functional as F
  552. x = np.random.random_integers(0, 255, (1,2,224,244,4)).astype("uint8")
  553. src = tensor(x)
  554. result = F.nn.nvof(src, precision=1)
  555. print(result.numpy())
  556. """
  557. assert src.ndim == 5 and src.shape[4] == 4
  558. src = src.detach()
  559. op = builtin.NvOf(precision=precision)
  560. return apply(op, src)[0]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台