
vision.py

# -*- coding: utf-8 -*-
from typing import Iterable, Optional, Tuple, Union

import numpy as np

from ..core import _config
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.utils import astensor1d
from ..tensor import Tensor
from .elemwise import floor
from .math import argsort
from .tensor import broadcast_to, concat, expand_dims, reshape, transpose

__all__ = [
    "correlation",
    "cvt_color",
    "interpolate",
    "nms",
    "nvof",
    "remap",
    "roi_align",
    "roi_pooling",
    "warp_affine",
    "warp_perspective",
]


def cvt_color(inp: Tensor, mode: str = ""):
    r"""Converts images from one color format to another.

    Args:
        inp: input images.
        mode: format mode.

    Returns:
        conversion result.

    Note:
        There are different supported modes for different combinations of :attr:`~.Tensor.device` and :attr:`~.Tensor.dtype`.

        x86/ARM:

            float32:
                "RGB2GRAY", "RGB2YUV", "YUV2RGB", "GRAY2RGB", "BGR2GRAY"

            uint8:
                "RGB2GRAY", "RGB2YUV", "YUV2RGB", "GRAY2RGB", "RGBA2RGB", "RGBA2BGR",
                "RGBA2GRAY", "RGB2BGR", "BGR2GRAY", "BGR2RGB", "YUV2GRAY_NV21", "YUV2RGB_NV21",
                "YUV2BGR_NV21", "YUV2GRAY_NV12", "YUV2RGB_NV12", "YUV2BGR_NV12", "YUV2GRAY_YV12",
                "YUV2RGB_YV12", "YUV2BGR_YV12", "YUV2GRAY_YU12", "YUV2RGB_YU12", "YUV2BGR_YU12",
                "YCrCb2RGB", "YCrCb2BGR", "BT601_YUV2RGB_NV21", "BT601_YUV2BGR_NV21", "BT601_YUV2RGB_NV12",
                "BT601_YUV2BGR_NV12", "BT601_YUV2RGB_YV12", "BT601_YUV2BGR_YV12", "BT601_YUV2RGB_YU12",
                "BT601_YUV2BGR_YU12"

        CUDA:

            float32:
                "RGB2GRAY", "BGR2GRAY", "RGB2YUV", "YUV2RGB", "GRAY2RGB"

            uint8:
                "RGB2GRAY", "BGR2GRAY", "RGB2YUV", "YUV2RGB", "GRAY2RGB",
                "YUV2GRAY_NV12", "YUV2GRAY_NV21", "YUV2GRAY_YU12",
                "YUV2GRAY_YV12", "YUV2RGB_NV12", "YUV2RGB_NV21", "YUV2BGR_NV12",
                "YUV2BGR_NV21", "YUV2RGB_YU12", "YUV2RGB_YV12", "YUV2BGR_YU12",
                "YUV2BGR_YV12"

    Examples:
        >>> import numpy as np
        >>> x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32))
        >>> y = F.vision.cvt_color(x, mode="RGB2GRAY")
        >>> y.numpy()
        array([[[[0.86555195]]]], dtype=float32)
    """
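    # mode names in the enum are upper-case except the "YCrCb" family,
    # which keeps its mixed case, so only upper-case the remaining modes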
    mode = mode.upper() if "YCrCb" not in mode else mode
    assert mode in builtin.CvtColor.Mode.__dict__, "unsupported mode for cvt_color"
    mode = getattr(builtin.CvtColor.Mode, mode)
    assert isinstance(mode, builtin.CvtColor.Mode)
    op = builtin.CvtColor(mode=mode)
    (out,) = apply(op, inp)
    return out


def roi_pooling(
    inp: Tensor,
    rois: Tensor,
    output_shape: Union[int, tuple, list],
    mode: str = "max",
    scale: float = 1.0,
) -> Tensor:
    r"""Applies RoI (Region of Interest) pooling on input feature, as described in Faster R-CNN.

    .. seealso::

        * `Region of interest pooling explained <https://deepsense.ai/region-of-interest-pooling-explained/>`_
        * `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_

    Args:
        inp: the input tensor that represents the input feature with ``(n, c, h, w)`` shape.
        rois: a tensor that represents Regions of Interest with shape ``(K, 5)``, i.e. ``K`` box coordinates in ``(idx, x1, y1, x2, y2)`` format specifying where the regions will be taken from.
            The coordinates ``(x1, y1)`` and ``(x2, y2)`` must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            The first column ``idx`` should contain the index of the corresponding element in the input batch, i.e. a number in ``[0, n - 1]``.
        output_shape: ``(height, width)`` shape of the output RoI feature.
        mode: "max" or "average", the pooling mode to be used. Default: "max"
        scale: a scale factor that maps the RoI coordinates to the input feature coordinates. For example, if the RoIs are defined on a 224 * 224 image and the input is a 112 * 112 feature map, the scale should be set to 0.5. Default: 1.0

    Returns:
        output tensor. ``(K, C, output_shape[0], output_shape[1])`` feature of RoIs.

    Examples:
        >>> import numpy as np
        >>> np.random.seed(42)
        >>> inp = Tensor(np.random.randn(1, 1, 128, 128))
        >>> rois = Tensor(np.random.random((4, 5)))
        >>> y = F.vision.roi_pooling(inp, rois, (2, 2))
        >>> y.numpy()[0].round(decimals=4)
        array([[[-0.1383, -0.1383],
                [-0.5035, -0.5035]]], dtype=float32)
    """
    assert mode.lower() in ["max", "average"], "only max/average mode is supported"
    if isinstance(output_shape, int):
        output_shape = (output_shape, output_shape)
    op = builtin.ROIPooling(mode=mode, scale=scale)
    result, _ = apply(
        op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device)
    )
    return result


def correlation(
    data1: Tensor,
    data2: Tensor,
    kernel_size: int = 1,
    max_displacement: int = 1,
    stride1: int = 1,
    stride2: int = 1,
    pad_size: int = 0,
    is_multiply: bool = True,
) -> Tensor:
    r"""Applies correlation to inputs.

    Args:
        data1: input data1 to the correlation. The format must be NCHW.
        data2: input data2 to the correlation. The format must be NCHW.
        kernel_size: int (non-negative), optional, default=1. Kernel size for the correlation; must be an odd number.
        max_displacement: int (non-negative), optional, default=1. Max displacement of the correlation.
        stride1: int (non-negative), optional, default=1. Stride that quantizes data1 globally.
        stride2: int (non-negative), optional, default=1. Stride that quantizes data2 within the neighborhood centered around data1.
        pad_size: int (non-negative), optional, default=0. Padding for the correlation.
        is_multiply: boolean, optional, default=True. The operation type is either multiplication or absolute difference.
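
    Examples:
        >>> import numpy as np
        >>> # a minimal sketch with assumed shapes; only the call pattern is
        >>> # shown here, output values are not asserted
        >>> data1 = Tensor(np.random.randn(1, 1, 8, 8).astype(np.float32))
        >>> data2 = Tensor(np.random.randn(1, 1, 8, 8).astype(np.float32))
        >>> out = F.vision.correlation(data1, data2, kernel_size=1, max_displacement=1)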
    """
    # currently correlation only supports NCHW mode
    format = "NCHW"
    op = builtin.Correlation(
        format=format,
        kernel_size=kernel_size,
        max_displacement=max_displacement,
        stride1=stride1,
        stride2=stride2,
        pad_size=pad_size,
        is_multiply=is_multiply,
    )
    result, *_ = apply(op, data1, data2)
    return result


def roi_align(
    inp: Tensor,
    rois: Tensor,
    output_shape: Union[int, tuple, list],
    mode: str = "average",
    spatial_scale: float = 1.0,
    sample_points: Union[int, tuple, list] = 2,
    aligned: bool = True,
) -> Tensor:
    r"""Applies RoI (Region of Interest) align on input feature, as described in Mask R-CNN.

    .. seealso::

        * `RoIAlign <https://paperswithcode.com/method/roi-align>`_
        * `Mask R-CNN <https://arxiv.org/abs/1703.06870v3>`_

    Args:
        inp: the input tensor that represents the input feature with ``(n, c, h, w)`` shape.
        rois: a tensor that represents Regions of Interest with shape ``(K, 5)``, i.e. ``K`` box coordinates in ``(idx, x1, y1, x2, y2)`` format specifying where the regions will be taken from.
            The coordinates ``(x1, y1)`` and ``(x2, y2)`` must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            The first column ``idx`` should contain the index of the corresponding element in the input batch, i.e. a number in ``[0, n - 1]``.
        output_shape: ``(height, width)`` shape of the output RoI feature.
        mode: "max" or "average", use max/average align just like max/average pooling. Default: "average"
        spatial_scale: scale the input boxes by this number. Default: 1.0
        sample_points: number of input samples to take for each output sample.
            0 to take samples densely. Default: 2
        aligned: whether to align the input feature. With ``aligned=True``,
            we first appropriately scale the RoI and then shift it by -0.5. Default: True

    Returns:
        output tensor.

    Examples:
        >>> import numpy as np
        >>> np.random.seed(42)
        >>> inp = Tensor(np.random.randn(1, 1, 128, 128))
        >>> rois = Tensor(np.random.random((4, 5)))
        >>> y = F.vision.roi_align(inp, rois, (2, 2))
        >>> y.numpy()[0].round(decimals=4)
        array([[[0.175 , 0.175 ],
                [0.1359, 0.1359]]], dtype=float32)
    """
    if inp.dtype != np.float32:
        inp = inp.astype(np.float32)
    mode = mode.lower()
    assert mode in ["max", "average"], "only max/average mode is supported"
    if isinstance(output_shape, int):
        output_shape = (output_shape, output_shape)
    pooled_height, pooled_width = output_shape
    if isinstance(sample_points, int):
        sample_points = (sample_points, sample_points)
    sample_height, sample_width = sample_points
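    # aligned=True shifts the scaled RoI by -0.5 so that sampling bins line
    # up with pixel centers rather than pixel corners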
    offset = 0.5 if aligned else 0.0
    # currently roi_align only supports NCHW mode
    format = "NCHW"
    op = builtin.ROIAlign(
        mode=mode,
        format=format,
        spatial_scale=spatial_scale,
        offset=offset,
        pooled_height=pooled_height,
        pooled_width=pooled_width,
        sample_height=sample_height,
        sample_width=sample_width,
    )
    result, *_ = apply(op, inp, rois)
    return result


def nms(
    boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None
) -> Tensor:
    r"""Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU).

    Args:
        boxes: tensor of shape ``(N, 4)``; the boxes to perform nms on; each box is expected to be in ``(x1, y1, x2, y2)`` format.
        scores: tensor of shape ``(N,)``, the score of boxes.
        iou_thresh: IoU threshold for overlapping.
        max_output: the maximum number of boxes to keep; it is optional if this operator is not traced,
            otherwise it is required to be specified; if it is not specified, all boxes are kept.

    Returns:
        indices of the elements that have been kept by NMS, sorted by scores.

    Note:
        max_output should be specified and should have a valid positive value under tracing.

    Examples:
        >>> import numpy as np
        >>> x = np.zeros((100,4))
        >>> np.random.seed(42)
        >>> x[:,:2] = np.random.rand(100,2)*20
        >>> x[:,2:] = np.random.rand(100,2)*20 + 100
        >>> scores = Tensor(np.random.rand(100))
        >>> inp = Tensor(x)
        >>> F.vision.nms(inp, scores, iou_thresh=0.7)
        Tensor([75 69], dtype=int32, device=xpux:0)
    """
    assert (
        boxes.ndim == 2 and boxes.shape[1] == 4
    ), "the expected shape of boxes is (N, 4)"
    assert scores.ndim == 1, "the expected shape of scores is (N,)"
    assert (
        boxes.shape[0] == scores.shape[0]
    ), "number of boxes and scores are not matched"

    boxes = boxes.detach()
    scores = scores.detach()
    sorted_idx = argsort(scores, descending=True)
    boxes = boxes[sorted_idx]

    if max_output is None:
        max_output = boxes.shape[0]
    op = builtin.NMSKeep(iou_thresh, max_output)
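    # NMSKeep expects batched input, so view the sorted boxes as (1, N, 4);
    # count[0] is the number of valid indices in the single batch entry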
    inp = (boxes.reshape(1, -1, 4),)
    indices, count = apply(op, *inp)
    indices = indices[0][: count[0]]
    keep_inds = sorted_idx[indices]
    return keep_inds


def remap(
    inp: Tensor,
    map_xy: Tensor,
    border_mode: str = "replicate",
    scalar: float = 0.0,
    interp_mode: str = "linear",
) -> Tensor:
    r"""Applies remap transformation to batched 2D images. Remap is an operation that relocates pixels in an image to another location in a new image.

    The input images are transformed to the output images by the tensor ``map_xy``.
    The output's H and W are the same as ``map_xy``'s H and W.

    Args:
        inp: input image, its shape represents ``[b, c, in_h, in_w]``.
        map_xy: transformation matrix, its shape should be ``[b, o_h, o_w, 2]``. The shape of the output is determined by o_h and o_w.
            For each element in the output, its value is determined by inp and ``map_xy``.
            ``map_xy[..., 0]`` and ``map_xy[..., 1]`` are the x and y positions in inp at which
            the current element is sampled, so their ranges are ``[0, in_w - 1]`` and ``[0, in_h - 1]``, respectively.
        border_mode: pixel extrapolation method. Default: "replicate". Currently also supports "constant", "reflect", "reflect_101", "wrap".

            "replicate": extends the border by replicating the edge pixel values, so the
            new boundary pixels take the values of the nearest edge pixels.

            "constant": fills the edges of the image with a fixed numeric value.
        scalar: value used in case of a constant border. Default: 0
        interp_mode: interpolation method. Default: "linear". Currently also supports "nearest" mode.

    Returns:
        output tensor. [b, c, o_h, o_w]

    Examples:
        >>> import numpy as np
        >>> inp_shape = (1, 1, 4, 4)
        >>> inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
        >>> map_xy_shape = (1, 2, 2, 2)
        >>> map_xy = Tensor(np.array([[[1., 0.],[0., 1.]],
        ...                           [[0., 1.],[0., 1.]]],
        ...                          dtype=np.float32).reshape(map_xy_shape))
        >>> out = F.vision.remap(inp, map_xy)
        >>> out.numpy()
        array([[[[1., 4.],
                 [4., 4.]]]], dtype=float32)
    """
    format = "NCHW"
    op = builtin.Remap(
        imode=interp_mode, border_type=border_mode, format=format, scalar=scalar
    )
    assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
    (result,) = apply(op, inp, map_xy)
    return result


def warp_affine(
    inp: Tensor,
    mat: Tensor,
    out_shape: Union[Tuple[int, int], int, Tensor],
    border_mode: str = "replicate",
    border_val: float = 0.0,
    format: str = "NHWC",
    interp_mode: str = "linear",
) -> Tensor:
    r"""Batched affine transformation on 2D images. An affine transformation is a linear transformation between two-dimensional coordinates.

    Args:
        inp: input image.
        mat: ``(batch, 2, 3)`` transformation matrix.
        out_shape: output tensor shape.
        border_mode: pixel extrapolation method.
            Default: "replicate". Currently "constant", "reflect",
            "reflect_101", "isolated", "wrap", "replicate", "transparent" are supported.
        border_val: value used in case of a constant border. Default: 0
        format: "NHWC" as default based on historical concerns,
            "NCHW" is also supported. Default: "NHWC".
        interp_mode: interpolation methods. Could be "linear", "nearest", "cubic", "area".
            Default: "linear".

    Returns:
        output tensor.

    Note:
        Here all available options for params are listed,
        however it does not mean that you can use all the combinations.
        On different platforms, different combinations are supported.
        ``warp_affine`` only supports forward inference. Please refer to ``warp_perspective`` if backward is needed.
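
    Examples:
        >>> import numpy as np
        >>> # a minimal sketch with an assumed NHWC input and an identity-like
        >>> # affine matrix; output values are not asserted here
        >>> x = Tensor(np.random.randn(1, 4, 4, 3).astype(np.float32))
        >>> M = Tensor(np.array([[[1., 0., 0.],
        ...                       [0., 1., 0.]]], dtype=np.float32))
        >>> y = F.vision.warp_affine(x, M, (2, 2))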
    """
    op = builtin.WarpAffine(
        border_mode=border_mode,
        border_val=border_val,
        format=format,
        imode=interp_mode,
    )
    out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device)
    (result,) = apply(op, inp, mat, out_shape)
    return result


def warp_perspective(
    inp: Tensor,
    mat: Tensor,
    out_shape: Union[Tuple[int, int], int, Tensor],
    mat_idx: Optional[Union[Iterable[int], Tensor]] = None,
    border_mode: str = "replicate",
    border_val: float = 0.0,
    format: str = "NCHW",
    interp_mode: str = "linear",
) -> Tensor:
    r"""Applies perspective transformation to batched 2D images. A perspective transformation is a projection of an image onto a new view plane.

    The input images are transformed to the output images by the transformation matrix:

    .. math::

        \text{output}(n, c, h, w) = \text{input} \left( n, c,
            \frac{M_{00}w + M_{01}h + M_{02}}{M_{20}w + M_{21}h + M_{22}},
            \frac{M_{10}w + M_{11}h + M_{12}}{M_{20}w + M_{21}h + M_{22}}
        \right)

    Optionally, we can set ``mat_idx`` to assign different transformations to the same image,
    otherwise the input images and transformations should be in one-to-one correspondence.

    Args:
        inp: input image.
        mat: ``(batch, 3, 3)`` transformation matrix.
        out_shape: ``(h, w)`` size of the output image.
        mat_idx: image batch idx assigned to each matrix. Default: None
        border_mode: pixel extrapolation method.
            Default: "replicate". Currently also supports "constant", "reflect",
            "reflect_101", "wrap".
        border_val: value used in case of a constant border. Default: 0
        format: "NHWC" is also supported. Default: "NCHW".
        interp_mode: interpolation method.
            Default: "linear". Currently only "linear" mode is supported.

    Returns:
        output tensor.

    Note:
        The transformation matrix is the inverse of that used by ``cv2.warpPerspective``.

    Examples:
        >>> import numpy as np
        >>> inp_shape = (1, 1, 4, 4)
        >>> x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
        >>> M_shape = (1, 3, 3)
        >>> # M defines a translation: dst(1, 1, h, w) = src(1, 1, h+1, w+1)
        >>> M = Tensor(np.array([[1., 0., 1.],
        ...                      [0., 1., 1.],
        ...                      [0., 0., 1.]], dtype=np.float32).reshape(M_shape))
        >>> out = F.vision.warp_perspective(x, M, (2, 2))
        >>> out.numpy()
        array([[[[ 5.,  6.],
                 [ 9., 10.]]]], dtype=float32)
    """
    if inp.dtype == np.float32:
        mat = mat.astype("float32")
    if inp.dtype == np.float16:
        inp = inp.astype("float32")
    op = builtin.WarpPerspective(
        imode=interp_mode, bmode=border_mode, format=format, border_val=border_val
    )
    out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
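    # with mat_idx, each transformation matrix picks its source image by
    # batch index, so one image may receive several transformations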
    if mat_idx is not None:
        mat_idx = astensor1d(mat_idx, inp, dtype="int32", device=inp.device)
        (result,) = apply(op, inp, mat, mat_idx, out_shape)
        return result
    (result,) = apply(op, inp, mat, out_shape)
    return result


def interpolate(
    inp: Tensor,
    size: Optional[Union[int, Tuple[int, int]]] = None,
    scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
    mode: str = "bilinear",
    align_corners: Optional[bool] = None,
) -> Tensor:
    r"""Down/up samples the input tensor to either the given ``size`` or with the given ``scale_factor``. ``size`` cannot coexist with ``scale_factor``.

    Args:
        inp: input tensor.
        size: the size of the output tensor. Default: None
        scale_factor: scaling factor of the output tensor. Default: None
        mode: interpolation methods, acceptable values are:
            "bilinear", "linear", "bicubic" and "nearest". Default: "bilinear"
        align_corners: this only has an effect when ``mode``
            is "bilinear" or "linear". Geometrically, we consider the pixels of the input
            and output as squares rather than points. If set to ``True``, the input
            and output tensors are aligned by the center points of their corner
            pixels, preserving the values at the corner pixels. If set to ``False``,
            the input and output tensors are aligned by the corner points of their
            corner pixels, and the interpolation uses edge value padding for
            out-of-boundary values, making this operation *independent* of input size.

    Returns:
        output tensor.

    Examples:
        >>> import numpy as np
        >>> x = Tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
        >>> out = F.vision.interpolate(x, [4, 4], align_corners=False)
        >>> out.numpy()
        array([[[[1.  , 1.25, 1.75, 2.  ],
                 [1.5 , 1.75, 2.25, 2.5 ],
                 [2.5 , 2.75, 3.25, 3.5 ],
                 [3.  , 3.25, 3.75, 4.  ]]]], dtype=float32)
        >>> out2 = F.vision.interpolate(x, scale_factor=2.)
        >>> np.testing.assert_allclose(out.numpy(), out2.numpy())
    """
    mode = mode.lower()
    if mode not in ["bilinear", "linear", "bicubic", "nearest"]:
        raise ValueError("unsupported interpolate mode: {}".format(mode))
    if mode not in ["bilinear", "linear"]:
        if align_corners is not None:
            raise ValueError(
                "align_corners option can only be set in the bilinear/linear interpolating mode"
            )
    else:
        if align_corners is None:
            align_corners = False

    if mode == "linear":
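        # linear mode works on 3-d input; add a trailing singleton axis so
        # the rest of the 4-d resize path can be reused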
        inp = expand_dims(inp, 3)

    if inp.ndim != 4:
        raise ValueError("shape of input tensor must correspond to the operation mode")
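    # convert a (possibly scalar) scale_factor into an absolute (h, w) size
    # measured on the (possibly expanded) 4-d input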
    def get_dsize(scale_factor):
        if isinstance(scale_factor, (float, int)):
            scale_factor = float(scale_factor)
            if mode == "linear":
                scale_factor = (scale_factor, float(1))
            else:
                scale_factor = (scale_factor, scale_factor)
        else:
            if mode == "linear":
                raise ValueError(
                    "under linear mode, scale_factor can only be single value"
                )
            assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )"
            assert isinstance(scale_factor[0], float) and isinstance(
                scale_factor[1], float
            ), "scale_factor must be float type"
        dsize = tuple(
            floor(
                Tensor(
                    inp.shape[i + 2] * scale_factor[i],
                    dtype="float32",
                    device=inp.device,
                )
            )
            for i in range(2)
        )
        dsize = concat([dsize[0], dsize[1]], axis=0)
        return dsize
    if size is None:
        if scale_factor is None:
            raise ValueError("scale_factor must not be None when size is None")
        dsize = get_dsize(scale_factor)
    else:
        if scale_factor is not None:
            raise ValueError("scale_factor must be None when size is provided")
        if isinstance(size, int):
            size = (size, 1)
        else:
            if mode == "linear":
                raise ValueError("under linear mode, size can only be single value")
        dsize = size
    if not align_corners:
        # fastpath for interpolate
        mode_map = {
            "linear": "linear",
            "bilinear": "linear",
            "nearest": "nearest",
            "bicubic": "cubic",
        }
        if inp.dtype == np.float16:
            inp = inp.astype("float32")
        # currently resize only supports NCHW mode
        format = "NCHW"
        op = builtin.Resize(imode=mode_map[mode], format=format)
        shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
        (ret,) = apply(op, inp, shape)
    else:
        assert mode in [
            "linear",
            "bilinear",
        ], "align_corners only supports linear or bilinear mode"
        oh, ow = dsize[0], dsize[1]
        ih, iw = inp.shape[2], inp.shape[3]
        hscale = (ih - 1.0) / (oh - 1.0)
        wscale = 1.0 * iw / ow
        if mode != "linear":
            wscale = (iw - 1.0) / (ow - 1.0)
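        # assemble a per-image 3x3 projective matrix
        #     [[wscale, 0, 0], [0, hscale, 0], [0, 0, 1]]
        # and resample through warp_perspective so that corner pixels of the
        # input and output stay aligned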
        row0 = concat(
            [
                Tensor(wscale, dtype="float32", device=inp.device),
                Tensor([0, 0], dtype="float32", device=inp.device),
            ],
            axis=0,
        ).reshape(1, 3)
        zeros = Tensor([0], dtype="float32", device=inp.device)
        row1 = concat(
            [zeros, Tensor(hscale, dtype="float32", device=inp.device), zeros], axis=0,
        ).reshape(1, 3)
        weight = concat(
            [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
            axis=0,
        ).reshape(1, 3, 3)
        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
        ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
    if mode == "linear":
        ret = reshape(ret, ret.shape[0:3])
    return ret


def nvof(src: Tensor, precision: int = 1) -> Tensor:
    r"""Implements NVIDIA Optical Flow SDK.

    Args:
        src: input tensor with shape ``(n, t, h, w, c4)`` and uint8 dtype.
        precision: 0: NV_OF_PERF_LEVEL_SLOW, 1: NV_OF_PERF_LEVEL_MEDIUM, 2: NV_OF_PERF_LEVEL_FAST.

    Returns:
        output tensor with shape ``(n, t-1, (h+out_grid_size-1)//out_grid_size, (w+out_grid_size-1)//out_grid_size, c2)``.
        By default, out_grid_size = 4. dtype: int16.
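
    Examples:
        >>> # illustrative sketch only: nvof needs NVIDIA Optical Flow capable
        >>> # hardware, so the call is shown as a comment rather than executed
        >>> # src = Tensor(np.random.randint(0, 255, (1, 2, 32, 32, 4)).astype(np.uint8))
        >>> # flow = F.vision.nvof(src, precision=1)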
    """
    assert src.ndim == 5 and src.shape[4] == 4
    src = src.detach()
    op = builtin.NvOf(precision=precision)
    return apply(op, src)[0]