|
|
@@ -249,6 +249,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: |
|
|
|
import numpy as np |
|
|
|
import megengine.functional as F |
|
|
|
from megengine.core import tensor |
|
|
|
|
|
|
|
inp = tensor(np.zeros(shape=(3,5),dtype=np.float32)) |
|
|
|
source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]]) |
|
|
|
index = tensor([[0,2,0,2,1],[2,0,0,1,2]]) |
|
|
@@ -258,6 +259,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: |
|
|
|
Outputs: |
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
:options: +SKIP |
|
|
|
|
|
|
|
[[0.9935 0.0718 0.5939 0. 0. ] |
|
|
|
[0. 0. 0. 0.357 0.4396] |
|
|
@@ -314,9 +316,9 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: |
|
|
|
def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: |
|
|
|
r""" |
|
|
|
Select elements either from Tensor x or Tensor y, according to mask. |
|
|
|
|
|
|
|
|
|
|
|
.. math:: |
|
|
|
|
|
|
|
|
|
|
|
\textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i |
|
|
|
|
|
|
|
:param mask: a mask used for choosing x or y |
|
|
|