|
|
@@ -391,6 +391,7 @@ def _get_idx(index, axis): |
|
|
|
|
|
|
|
|
|
|
|
def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: |
|
|
|
# TODO: rewrite doc |
|
|
|
r"""Gathers data from input tensor on axis using index. |
|
|
|
|
|
|
|
For a 3-D tensor, the output is specified by:: |
|
|
@@ -462,6 +463,7 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: |
|
|
|
|
|
|
|
|
|
|
|
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: |
|
|
|
# TODO: rewrite doc |
|
|
|
r"""Writes all values from the tensor source into input tensor |
|
|
|
at the indices specified in the index tensor. |
|
|
|
|
|
|
|