Browse Source

Pre Merge pull request !2104 from yanzhenxiang2020/gatherd_yzx

pull/2104/MERGE
yanzhenxiang2020 Gitee 3 years ago
parent
commit
9365958d5f
1 changed files with 24 additions and 0 deletions
  1. +24
    -0
      third_party/fwkacllib/inc/ops/selection_ops.h

+ 24
- 0
third_party/fwkacllib/inc/ops/selection_ops.h View File

@@ -178,6 +178,30 @@ REG_OP(GatherNd)
.OP_END_FACTORY_REG(GatherNd) .OP_END_FACTORY_REG(GatherNd)


/** /**
*@Gathers values along an axis specified by dim . \n

*@par Inputs:
*@li x: A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8,
* int64, uint16, float16, uint32, uint64, bool.
*@li dim: A Tensor. Must be one of the following types: int32, int64.
*@li index: A Tensor. Must be one of the following types: int32, int64 . \n


*@par Outputs:
* y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
*Compatible with the PyTorch operator Gather.
*/
REG_OP(GatherD)
.INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, DT_UINT32
DT_INT64, DT_UINT64, DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
.INPUT(dim, TensorType({DT_INT32, DT_INT64}))
.INPUT(index, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64}))
.OP_END_FACTORY_REG(GatherD)

/**
*@brief Gather slices from "x" according to "indices" by corresponding axis . \n *@brief Gather slices from "x" according to "indices" by corresponding axis . \n


*@par Inputs: *@par Inputs:


Loading…
Cancel
Save