From b9aaf217a55b96fcbd1c06320abe457e3ee1fb49 Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Fri, 7 Jan 2022 12:14:28 +0800 Subject: [PATCH] add GatherD IR --- third_party/fwkacllib/inc/ops/selection_ops.h | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/third_party/fwkacllib/inc/ops/selection_ops.h b/third_party/fwkacllib/inc/ops/selection_ops.h index b92af69e..04b709bc 100644 --- a/third_party/fwkacllib/inc/ops/selection_ops.h +++ b/third_party/fwkacllib/inc/ops/selection_ops.h @@ -178,6 +178,30 @@ REG_OP(GatherNd) .OP_END_FACTORY_REG(GatherNd) /** +*@Gathers values along an axis specified by dim . \n + +*@par Inputs: +*@li x: A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, +* int64, uint16, float16, uint32, uint64, bool. +*@li dim: A Tensor. Must be one of the following types: int32, int64. +*@li index: A Tensor. Must be one of the following types: int32, int64 . \n + + +*@par Outputs: +* y: A Tensor. Has the same type as "x" . \n + +*@par Third-party framework compatibility +*Compatible with the PyTorch operator Gather. +*/ +REG_OP(GatherD) + .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, DT_UINT32 + DT_INT64, DT_UINT64, DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(dim, TensorType({DT_INT32, DT_INT64})) + .INPUT(index, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(GatherD) + +/** *@brief Gather slices from "x" according to "indices" by corresponding axis . *@par Inputs: