Browse Source

fix(mge/functional): fix scatter doctest failed for GPU platform issue

GitOrigin-RevId: b5f92c39dd
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
adfa468899
1 changed files with 11 additions and 4 deletions
  1. +11
    -4
      python_module/megengine/functional/tensor.py

+ 11
- 4
python_module/megengine/functional/tensor.py View File

@@ -236,6 +236,14 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:

Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.

.. note::
Please notice that, due to performance issues, the result is uncertain on the GPU device
if scatter difference positions from source to the same destination position
regard to index tensor.

Show the case using the following examples, the oup[0][2] is maybe
from source[0][2] which value is 0.2256 or source[1][2] which value is 0.5339
if set the index[1][2] from 1 to 0.

:param inp: the inp tensor which to be scattered
:param axis: the axis along which to index
@@ -252,17 +260,16 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:

inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
index = tensor([[0,2,0,2,1],[2,0,0,1,2]])
index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
oup = F.scatter(inp, 0, index,source)
print(oup.numpy())

Outputs:

.. testoutput::
:options: +SKIP

[[0.9935 0.0718 0.5939 0. 0. ]
[0. 0. 0. 0.357 0.4396]
[[0.9935 0.0718 0.2256 0. 0. ]
[0. 0. 0.5939 0.357 0.4396]
[0.7723 0.9465 0. 0.8926 0.4576]]

"""


Loading…
Cancel
Save