diff --git a/third_party/fwkacllib/inc/ops/selection_ops.h b/third_party/fwkacllib/inc/ops/selection_ops.h index 2c99e82e..e815599e 100644 --- a/third_party/fwkacllib/inc/ops/selection_ops.h +++ b/third_party/fwkacllib/inc/ops/selection_ops.h @@ -178,6 +178,30 @@ REG_OP(GatherNd) .OP_END_FACTORY_REG(GatherNd) /** +*@Gathers values along an axis specified by dim . \n + +*@par Inputs: +*@li x: A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, +* int64, uint16, float16, uint32, uint64, bool. +*@li dim: A Tensor. Must be one of the following types: int32, int64. +*@li index: A Tensor. Must be one of the following types: int32, int64 . \n + + +*@par Outputs: +* y: A Tensor. Has the same type as "x" . \n + +*@par Third-party framework compatibility +*Compatible with the PyTorch operator Gather. +*/ +REG_OP(GatherD) + .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, DT_UINT32 + DT_INT64, DT_UINT64, DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(dim, TensorType({DT_INT32, DT_INT64})) + .INPUT(index, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(GatherD) + +/** *@brief Gather slices from "x" according to "indices" by corresponding axis . \n *@par Inputs: