You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

vision.py 25 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Iterable, Optional, Tuple, Union
  10. import numpy as np
  11. from ..core import _config
  12. from ..core._imperative_rt.core2 import apply
  13. from ..core.ops import builtin
  14. from ..core.tensor import megbrain_graph, utils
  15. from ..core.tensor.utils import astensor1d
  16. from ..tensor import Tensor
  17. from .elemwise import floor
  18. from .math import argsort
  19. from .tensor import broadcast_to, concat, expand_dims, reshape, transpose
  20. __all__ = [
  21. "correlation",
  22. "cvt_color",
  23. "roi_pooling",
  24. "roi_align",
  25. "nms",
  26. "remap",
  27. "warp_affine",
  28. "warp_perspective",
  29. "interpolate",
  30. "nvof",
  31. ]
  32. def cvt_color(inp: Tensor, mode: str = ""):
  33. r"""Convert images from one format to another
  34. Args:
  35. inp: input images.
  36. mode: format mode.
  37. Returns:
  38. convert result.
  39. Note:
  40. There are different supported modes for different combinations of :attr:`~.Tensor.device` and :attr:`~.Tensor.dtype`.
  41. x86/ARM:
  42. float32:
  43. "RGB2GRAY", "RGB2YUV", "YUV2RGB", "GRAY2RGB", "BGR2GRAY"
  44. uint8:
  45. "RGB2GRAY", "RGB2YUV", "YUV2RGB", "GRAY2RGB", "RGBA2RGB", "RGBA2BGR",
  46. "RGBA2GRAY", "RGB2BGR", "BGR2GRAY", "BGR2RGB", "YUV2GRAY_NV21", "YUV2RGB_NV21",
  47. "YUV2BGR_NV21", "YUV2GRAY_NV12", "YUV2RGB_NV12", "YUV2BGR_NV12", "YUV2GRAY_YV12",
  48. "YUV2RGB_YV12", "YUV2BGR_YV12", "YUV2GRAY_YU12", "YUV2RGB_YU12", "YUV2BGR_YU12",
  49. "YCrCb2RGB", "YCrCb2BGR", "BT601_YUV2RGB_NV21", "BT601_YUV2BGR_NV21", "BT601_YUV2RGB_NV12",
  50. "BT601_YUV2BGR_NV12", "BT601_YUV2RGB_YV12", "BT601_YUV2BGR_YV12" ,"BT601_YUV2RGB_YU12",
  51. "BT601_YUV2BGR_YU12"
  52. CUDA:
  53. float32:
  54. "RGB2GRAY", "BGR2GRAY", "RGB2YUV", "YUV2RGB", "GRAY2RGB"
  55. uint8:
  56. "RGB2GRAY", "BGR2GRAY", "RGB2YUV", "YUV2RGB", "GRAY2RGB",
  57. "YUV2GRAY_NV12", "YUV2GRAY_NV21", "YUV2GRAY_YU12"
  58. "YUV2GRAY_YV12", "YUV2RGB_NV12", "YUV2RGB_NV21", "YUV2BGR_NV12"
  59. "YUV2BGR_NV21", "YUV2RGB_YU12", "YUV2RGB_YV12", "YUV2BGR_YU12",
  60. "YUV2BGR_YV12"
  61. Examples:
  62. .. testcode::
  63. import numpy as np
  64. import megengine as mge
  65. import megengine.functional as F
  66. x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32))
  67. y = F.vision.cvt_color(x, mode="RGB2GRAY")
  68. print(y.numpy())
  69. Outputs:
  70. .. testoutput::
  71. [[[[0.86555195]]]]
  72. """
  73. mode = mode.upper() if "YCrCb" not in mode else mode
  74. assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color"
  75. mode = getattr(builtin.CvtColor.Mode, mode)
  76. assert isinstance(mode, builtin.CvtColor.Mode)
  77. op = builtin.CvtColor(mode=mode)
  78. (out,) = apply(op, inp)
  79. return out
  80. def roi_pooling(
  81. inp: Tensor,
  82. rois: Tensor,
  83. output_shape: Union[int, tuple, list],
  84. mode: str = "max",
  85. scale: float = 1.0,
  86. ) -> Tensor:
  87. r"""Applies roi pooling on input feature.
  88. Args:
  89. inp: tensor that represents the input feature, `(N, C, H, W)` images.
  90. rois: K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
  91. output_shape: height, width)` of output rois feature.
  92. mode: max" or "average", use max/average align just like max/average pooling. Default: "max"
  93. scale: scale the input boxes by this number. Default: 1.0
  94. Returns:
  95. ``K, C, output_shape[0], output_shape[1])`` feature of rois.
  96. Examples:
  97. .. testcode::
  98. import numpy as np
  99. from megengine import tensor
  100. import megengine.functional as F
  101. np.random.seed(42)
  102. inp = tensor(np.random.randn(1, 1, 128, 128))
  103. rois = tensor(np.random.random((4, 5)))
  104. y = F.vision.roi_pooling(inp, rois, (2, 2))
  105. print(y.numpy()[0].round(decimals=4))
  106. Outputs:
  107. .. testoutput::
  108. [[[-0.1383 -0.1383]
  109. [-0.5035 -0.5035]]]
  110. """
  111. assert mode.lower() in ["max", "average"], "only max/average mode is supported"
  112. if isinstance(output_shape, int):
  113. output_shape = (output_shape, output_shape)
  114. op = builtin.ROIPooling(mode=mode, scale=scale)
  115. result, _ = apply(
  116. op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device)
  117. )
  118. return result
  119. def correlation(
  120. data1: Tensor,
  121. data2: Tensor,
  122. kernel_size: int = 1,
  123. max_displacement: int = 1,
  124. stride1: int = 1,
  125. stride2: int = 1,
  126. pad_size: int = 0,
  127. is_multiply: bool = True,
  128. ) -> Tensor:
  129. r"""Applies correlation to inputs.
  130. Args:
  131. data1: Input data1 to the correlation. format must be nchw
  132. data2: Input data2 to the correlation. format must be nchw
  133. kernel_size: int (non-negative), optional, default=1) – kernel size for Correlation must be an odd number
  134. max_displacement: int (non-negative), optional, default=1) – Max displacement of Correlation
  135. stride1: int (non-negative), optional, default=1) – stride1 quantize data1 globally
  136. stride2: int (non-negative), optional, default=1) – stride2 quantize data2 within the neighborhood centered around data1
  137. pad_size: int (non-negative), optional, default=0) – pad for Correlation
  138. is_multiply: boolean, optional, default=True) – operation type is either multiplication or absolute difference
  139. """
  140. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  141. assert conv_format == "NCHW", "Currently correlation only support NCHW mode"
  142. op = builtin.Correlation(
  143. format=conv_format,
  144. kernel_size=kernel_size,
  145. max_displacement=max_displacement,
  146. stride1=stride1,
  147. stride2=stride2,
  148. pad_size=pad_size,
  149. is_multiply=is_multiply,
  150. )
  151. result, *_ = apply(op, data1, data2)
  152. return result
  153. def roi_align(
  154. inp: Tensor,
  155. rois: Tensor,
  156. output_shape: Union[int, tuple, list],
  157. mode: str = "average",
  158. spatial_scale: float = 1.0,
  159. sample_points: Union[int, tuple, list] = 2,
  160. aligned: bool = True,
  161. ) -> Tensor:
  162. r"""Applies roi align on input feature.
  163. Args:
  164. inp: tensor that represents the input feature, shape is `(N, C, H, W)`.
  165. rois: N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
  166. output_shape: height, width)` shape of output rois feature.
  167. mode: max" or "average", use max/average align just like max/average pooling. Default: "average"
  168. spatial_scale: scale the input boxes by this number. Default: 1.0
  169. sample_points: number of inputs samples to take for each output sample.
  170. 0 to take samples densely. Default: 2
  171. aligned: wheather to align the input feature, with `aligned=True`,
  172. we first appropriately scale the ROI and then shift it by -0.5. Default: True
  173. Returns:
  174. output tensor.
  175. Examples:
  176. .. testcode::
  177. import numpy as np
  178. from megengine import tensor
  179. import megengine.functional as F
  180. np.random.seed(42)
  181. inp = tensor(np.random.randn(1, 1, 128, 128))
  182. rois = tensor(np.random.random((4, 5)))
  183. y = F.vision.roi_align(inp, rois, (2, 2))
  184. print(y.numpy()[0].round(decimals=4))
  185. Outputs:
  186. .. testoutput::
  187. [[[0.175 0.175 ]
  188. [0.1359 0.1359]]]
  189. """
  190. if inp.dtype != np.float32:
  191. inp = inp.astype(np.float32)
  192. mode = mode.lower()
  193. assert mode in ["max", "average"], "only max/average mode is supported"
  194. if isinstance(output_shape, int):
  195. output_shape = (output_shape, output_shape)
  196. pooled_height, pooled_width = output_shape
  197. if isinstance(sample_points, int):
  198. sample_points = (sample_points, sample_points)
  199. sample_height, sample_width = sample_points
  200. offset = 0.5 if aligned else 0.0
  201. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  202. assert conv_format == "NCHW", "Currently roi_align only support NCHW mode"
  203. op = builtin.ROIAlign(
  204. mode=mode,
  205. format=conv_format,
  206. spatial_scale=spatial_scale,
  207. offset=offset,
  208. pooled_height=pooled_height,
  209. pooled_width=pooled_width,
  210. sample_height=sample_height,
  211. sample_width=sample_width,
  212. )
  213. result, *_ = apply(op, inp, rois)
  214. return result
  215. def nms(
  216. boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None
  217. ) -> Tensor:
  218. r"""Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU).
  219. Args:
  220. boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
  221. iou_thresh: IoU threshold for overlapping.
  222. scores: tensor of shape `(N,)`, the score of boxes.
  223. max_output: the maximum number of boxes to keep; it is optional if this operator is not traced
  224. otherwise it required to be specified; if it is not specified, all boxes are kept.
  225. Returns:
  226. indices of the elements that have been kept by NMS, sorted by scores.
  227. Note:
  228. max_output should be specified and should have valid positive value under tracing.
  229. Examples:
  230. .. testcode::
  231. import numpy as np
  232. from megengine import tensor
  233. import megengine.functional as F
  234. x = np.zeros((100,4))
  235. np.random.seed(42)
  236. x[:,:2] = np.random.rand(100,2)*20
  237. x[:,2:] = np.random.rand(100,2)*20 + 100
  238. scores = tensor(np.random.rand(100))
  239. inp = tensor(x)
  240. result = F.vision.nms(inp, scores, iou_thresh=0.7)
  241. print(result.numpy())
  242. Outputs:
  243. .. testoutput::
  244. [75 69]
  245. """
  246. assert (
  247. boxes.ndim == 2 and boxes.shape[1] == 4
  248. ), "the expected shape of boxes is (N, 4)"
  249. assert scores.ndim == 1, "the expected shape of scores is (N,)"
  250. assert (
  251. boxes.shape[0] == scores.shape[0]
  252. ), "number of boxes and scores are not matched"
  253. boxes = boxes.detach()
  254. scores = scores.detach()
  255. sorted_idx = argsort(scores, descending=True)
  256. boxes = boxes[sorted_idx]
  257. if max_output is None:
  258. max_output = boxes.shape[0]
  259. op = builtin.NMSKeep(iou_thresh, max_output)
  260. inp = (boxes.reshape(1, -1, 4),)
  261. indices, count = apply(op, *inp)
  262. indices = indices[0][: count[0]]
  263. keep_inds = sorted_idx[indices]
  264. return keep_inds
  265. def remap(
  266. inp: Tensor,
  267. map_xy: Tensor,
  268. border_mode: str = "replicate",
  269. scalar: float = 0.0,
  270. interp_mode: str = "linear",
  271. ) -> Tensor:
  272. r"""Applies remap transformation to batched 2D images.
  273. The input images are transformed to the output images by the tensor map_xy.
  274. The output's H and W are same as map_xy's H and W.
  275. Args:
  276. inp: input image
  277. map_xy: batch, oh, ow, 2) transformation matrix
  278. border_mode: pixel extrapolation method.
  279. Default: "replicate". Currently also support "constant", "reflect",
  280. "reflect_101", "wrap".
  281. scalar: value used in case of a constant border. Default: 0
  282. interp_mode: interpolation methods.
  283. Default: "linear". Currently only support "linear" mode.
  284. Returns:
  285. output tensor.
  286. Examples:
  287. .. testcode::
  288. import numpy as np
  289. from megengine import tensor
  290. import megengine.functional as F
  291. inp_shape = (1, 1, 4, 4)
  292. inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  293. map_xy_shape = (1, 2, 2, 2)
  294. map_xy = tensor(np.array([[[1., 0.],[0., 1.]],
  295. [[0., 1.],[0., 1.]]],
  296. dtype=np.float32).reshape(map_xy_shape))
  297. out = F.vision.remap(inp, map_xy)
  298. print(out.numpy())
  299. Outputs:
  300. .. testoutput::
  301. [[[[1. 4.]
  302. [4. 4.]]]]
  303. """
  304. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  305. op = builtin.Remap(
  306. imode=interp_mode, border_type=border_mode, format=conv_format, scalar=scalar
  307. )
  308. assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
  309. (result,) = apply(op, inp, map_xy)
  310. return result
  311. def warp_affine(
  312. inp: Tensor,
  313. mat: Tensor,
  314. out_shape: Union[Tuple[int, int], int, Tensor],
  315. border_mode: str = "replicate",
  316. border_val: float = 0.0,
  317. format: str = "NHWC",
  318. interp_mode: str = "linear",
  319. ) -> Tensor:
  320. r"""Batched affine transform on 2D images.
  321. Args:
  322. inp: input image.
  323. mat: batch, 2, 3)` transformation matrix.
  324. out_shape: output tensor shape.
  325. border_mode: pixel extrapolation method.
  326. Default: "wrap". Currently "constant", "reflect",
  327. "reflect_101", "isolated", "wrap", "replicate", "transparent" are supported.
  328. border_val: value used in case of a constant border. Default: 0
  329. format: NHWC" as default based on historical concerns,
  330. "NCHW" is also supported. Default: "NHWC".
  331. interp_mode: interpolation methods. Could be "linear", "nearest", "cubic", "area".
  332. Default: "linear".
  333. Returns:
  334. output tensor.
  335. Note:
  336. Here all available options for params are listed,
  337. however it does not mean that you can use all the combinations.
  338. On different platforms, different combinations are supported.
  339. ``warp_affine`` only support forward inference, Please refer to ``warp_perspective`` if backward is needed.
  340. """
  341. conv_format = _config._get_actual_op_param(format, _config.__conv_format)
  342. op = builtin.WarpAffine(
  343. border_mode=border_mode,
  344. border_val=border_val,
  345. format=conv_format,
  346. imode=interp_mode,
  347. )
  348. out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device)
  349. (result,) = apply(op, inp, mat, out_shape)
  350. return result
  351. def warp_perspective(
  352. inp: Tensor,
  353. mat: Tensor,
  354. out_shape: Union[Tuple[int, int], int, Tensor],
  355. mat_idx: Optional[Union[Iterable[int], Tensor]] = None,
  356. border_mode: str = "replicate",
  357. border_val: float = 0.0,
  358. format: str = "NCHW",
  359. interp_mode: str = "linear",
  360. ) -> Tensor:
  361. r"""Applies perspective transformation to batched 2D images.
  362. The input images are transformed to the output images by the transformation matrix:
  363. .. math::
  364. \text{output}(n, c, h, w) = \text{input} \left( n, c,
  365. \frac{M_{00}w + M_{01}h + M_{02}}{M_{20}w + M_{21}h + M_{22}},
  366. \frac{M_{10}w + M_{11}h + M_{12}}{M_{20}w + M_{21}h + M_{22}}
  367. \right)
  368. Optionally, we can set `mat_idx` to assign different transformations to the same image,
  369. otherwise the input images and transformations should be one-to-one correnspondence.
  370. Args:
  371. inp: input image.
  372. mat: batch, 3, 3)` transformation matrix.
  373. out_shape: h, w)` size of the output image.
  374. mat_idx: batch, )` image batch idx assigned to each matrix. Default: None
  375. border_mode: pixel extrapolation method.
  376. Default: "replicate". Currently also support "constant", "reflect",
  377. "reflect_101", "wrap".
  378. border_val: value used in case of a constant border. Default: 0
  379. format: NHWC" is also supported. Default: "NCHW".
  380. interp_mode: interpolation methods.
  381. Default: "linear". Currently only support "linear" mode.
  382. Returns:
  383. output tensor.
  384. Note:
  385. The transformation matrix is the inverse of that used by `cv2.warpPerspective`.
  386. Examples:
  387. .. testcode::
  388. import numpy as np
  389. from megengine import tensor
  390. import megengine.functional as F
  391. inp_shape = (1, 1, 4, 4)
  392. x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  393. M_shape = (1, 3, 3)
  394. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  395. M = tensor(np.array([[1., 0., 1.],
  396. [0., 1., 1.],
  397. [0., 0., 1.]], dtype=np.float32).reshape(M_shape))
  398. out = F.vision.warp_perspective(x, M, (2, 2))
  399. print(out.numpy())
  400. Outputs:
  401. .. testoutput::
  402. [[[[ 5. 6.]
  403. [ 9. 10.]]]]
  404. """
  405. if inp.dtype == np.float32:
  406. mat = mat.astype("float32")
  407. if inp.dtype == np.float16:
  408. inp = inp.astype("float32")
  409. conv_format = _config._get_actual_op_param(format, _config.__conv_format)
  410. op = builtin.WarpPerspective(
  411. imode=interp_mode, bmode=border_mode, format=conv_format, border_val=border_val
  412. )
  413. out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
  414. if mat_idx is not None:
  415. mat_idx = astensor1d(mat_idx, inp, dtype="int32", device=inp.device)
  416. (result,) = apply(op, inp, mat, mat_idx, out_shape)
  417. return result
  418. (result,) = apply(op, inp, mat, out_shape)
  419. return result
  420. def interpolate(
  421. inp: Tensor,
  422. size: Optional[Union[int, Tuple[int, int]]] = None,
  423. scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
  424. mode: str = "bilinear",
  425. align_corners: Optional[bool] = None,
  426. ) -> Tensor:
  427. r"""Down/up samples the input tensor to either the given size or with the given scale_factor. ``size`` can not coexist with ``scale_factor``.
  428. Args:
  429. inp: input tensor.
  430. size: size of the output tensor. Default: None
  431. scale_factor: scaling factor of the output tensor. Default: None
  432. mode: interpolation methods, acceptable values are:
  433. "bilinear", "linear", "bicubic" and "nearest". Default: "bilinear"
  434. align_corners: This only has an effect when `mode`
  435. is "bilinear" or "linear". Geometrically, we consider the pixels of the input
  436. and output as squares rather than points. If set to ``True``, the input
  437. and output tensors are aligned by the center points of their corner
  438. pixels, preserving the values at the corner pixels. If set to ``False``,
  439. the input and output tensors are aligned by the corner points of their
  440. corner pixels, and the interpolation uses edge value padding for
  441. out-of-boundary values, making this operation *independent* of input size
  442. Returns:
  443. output tensor.
  444. Examples:
  445. .. testcode::
  446. import numpy as np
  447. from megengine import tensor
  448. import megengine.functional as F
  449. x = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  450. out = F.vision.interpolate(x, [4, 4], align_corners=False)
  451. print(out.numpy())
  452. out2 = F.vision.interpolate(x, scale_factor=2.)
  453. np.testing.assert_allclose(out.numpy(), out2.numpy())
  454. Outputs:
  455. .. testoutput::
  456. [[[[1. 1.25 1.75 2. ]
  457. [1.5 1.75 2.25 2.5 ]
  458. [2.5 2.75 3.25 3.5 ]
  459. [3. 3.25 3.75 4. ]]]]
  460. """
  461. mode = mode.lower()
  462. if mode not in ["bilinear", "linear", "bicubic", "nearest"]:
  463. raise ValueError("unsupported interpolate mode: {}".format(mode))
  464. if mode not in ["bilinear", "linear"]:
  465. if align_corners is not None:
  466. raise ValueError(
  467. "align_corners option can only be set in the bilinear/linear interpolating mode"
  468. )
  469. else:
  470. if align_corners is None:
  471. align_corners = False
  472. if mode == "linear":
  473. inp = expand_dims(inp, 3)
  474. if inp.ndim != 4:
  475. raise ValueError("shape of input tensor must correspond to the operartion mode")
  476. def get_dsize(scale_factor):
  477. if isinstance(scale_factor, (float, int)):
  478. scale_factor = float(scale_factor)
  479. if mode == "linear":
  480. scale_factor = (scale_factor, float(1))
  481. else:
  482. scale_factor = (scale_factor, scale_factor)
  483. else:
  484. if mode == "linear":
  485. raise ValueError(
  486. "under linear mode, scale_factor can only be single value"
  487. )
  488. assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )"
  489. assert isinstance(scale_factor[0], float) and isinstance(
  490. scale_factor[1], float
  491. ), "scale_factor must be float type"
  492. dsize = tuple(
  493. floor(
  494. Tensor(
  495. inp.shape[i + 2] * scale_factor[i],
  496. dtype="float32",
  497. device=inp.device,
  498. )
  499. )
  500. for i in range(2)
  501. )
  502. dsize = concat([dsize[0], dsize[1]], axis=0)
  503. return dsize
  504. if size is None:
  505. if scale_factor is None:
  506. raise ValueError("scale_factor must not be None when size is None")
  507. dsize = get_dsize(scale_factor)
  508. else:
  509. if scale_factor is not None:
  510. raise ValueError("scale_factor must be None when size is provided")
  511. if isinstance(size, int):
  512. size = (size, 1)
  513. else:
  514. if mode == "linear":
  515. raise ValueError("under linear mode, size can only be single value")
  516. dsize = size
  517. if not align_corners:
  518. # fastpath for interpolate
  519. mode_map = {
  520. "linear": "linear",
  521. "bilinear": "linear",
  522. "nearest": "nearest",
  523. "bicubic": "cubic",
  524. }
  525. if inp.dtype == np.float16:
  526. inp = inp.astype("float32")
  527. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  528. assert conv_format == "NCHW", "Currently resize only support NCHW mode"
  529. op = builtin.Resize(imode=mode_map[mode], format=conv_format)
  530. shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
  531. (ret,) = apply(op, inp, shape)
  532. else:
  533. assert mode in [
  534. "linear",
  535. "bilinear",
  536. ], "align_corners only support linear or bilinear mode"
  537. oh, ow = dsize[0], dsize[1]
  538. ih, iw = inp.shape[2], inp.shape[3]
  539. hscale = (ih - 1.0) / (oh - 1.0)
  540. wscale = 1.0 * iw / ow
  541. if mode != "linear":
  542. wscale = (iw - 1.0) / (ow - 1.0)
  543. row0 = concat(
  544. [wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0
  545. ).reshape(1, 3)
  546. row1 = concat(
  547. [
  548. Tensor(0, dtype="float32", device=inp.device),
  549. hscale,
  550. Tensor(0, dtype="float32", device=inp.device),
  551. ],
  552. axis=0,
  553. ).reshape(1, 3)
  554. weight = concat(
  555. [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
  556. axis=0,
  557. ).reshape(1, 3, 3)
  558. weight = broadcast_to(weight, (inp.shape[0], 3, 3))
  559. ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
  560. if mode == "linear":
  561. ret = reshape(ret, ret.shape[0:3])
  562. return ret
  563. def nvof(src: Tensor, precision: int = 1) -> Tensor:
  564. r"""Implements NVIDIA Optical Flow SDK.
  565. Args:
  566. src: input tensor with shape (n, t, h, w, c4) and unit8 dtype.
  567. precision: 0:NV_OF_PERF_LEVEL_SLOW 1:NV_OF_PERF_LEVEL_MEDIUM 2:NV_OF_PERF_LEVEL_FAST.
  568. Returns:
  569. output tensor with shape: ``(n, t-1, (h+out_grid_size-1)//out_grid_size, (w+out_grid_size-1)//out_grid_size, c2)``.
  570. By default, out_grid_size = 4. dtype: int16.
  571. .. code-block:: python
  572. import numpy as np
  573. from megengine import tensor
  574. import megengine.functional as F
  575. x = np.random.random_integers(0, 255, (1,2,224,244,4)).astype("uint8")
  576. src = tensor(x)
  577. result = F.nn.nvof(src, precision=1)
  578. print(result.numpy())
  579. """
  580. assert src.ndim == 5 and src.shape[4] == 4
  581. src = src.detach()
  582. op = builtin.NvOf(precision=precision)
  583. return apply(op, src)[0]