From d168cea4a7ba519759f769a393efc79aed0cb2a0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 14 Dec 2020 14:03:29 +0800 Subject: [PATCH] feat(opr): add param(axis) for GetVarShape feat(mge/imperative): GetVarShape support negative axis GitOrigin-RevId: 30ce0758e66285e984de0bb410759032653cf20a --- imperative/src/impl/ops/tensor_manip.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 21ed0f92..e4a8344f 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -40,10 +40,14 @@ SmallVector apply_on_physical_tensor( ptr[i] = shp.shape[i]; } }else{ - mgb_assert(op_def.axis < shp.ndim); + int32_t axis = op_def.axis; + if (axis < 0) { + axis += shp.ndim; + } + mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim); hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32()); auto* ptr = hv.ptr(); - ptr[0] = shp.shape[op_def.axis]; + ptr[0] = shp.shape[axis]; } return {Tensor::make(std::move(hv))}; } @@ -65,10 +69,14 @@ std::tuple, bool> infer_output_attrs_fallible( ptr[i] = desc.layout[i]; } }else{ - mgb_assert(op_def.axis < desc.layout.ndim); + int32_t axis = op_def.axis; + if (axis < 0) { + axis += desc.layout.ndim; + } + mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim); value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32()); auto* ptr = value.ptr(); - ptr[0] = desc.layout[op_def.axis]; + ptr[0] = desc.layout[axis]; } return {{{value.layout(), desc.comp_node, std::move(value)}}, true}; }