Browse Source

feat(opr): add param(axis) for GetVarShape

feat(mge/imperative): GetVarShape support negative axis

GitOrigin-RevId: 30ce0758e6
release-1.2
Megvii Engine Team 4 years ago
parent
commit
d168cea4a7
1 changed files with 12 additions and 4 deletions
  1. +12
    -4
      imperative/src/impl/ops/tensor_manip.cpp

+ 12
- 4
imperative/src/impl/ops/tensor_manip.cpp View File

@@ -40,10 +40,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
ptr[i] = shp.shape[i];
}
}else{
mgb_assert(op_def.axis < shp.ndim);
int32_t axis = op_def.axis;
if (axis < 0) {
axis += shp.ndim;
}
mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = shp.shape[op_def.axis];
ptr[0] = shp.shape[axis];
}
return {Tensor::make(std::move(hv))};
}
@@ -65,10 +69,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
ptr[i] = desc.layout[i];
}
}else{
mgb_assert(op_def.axis < desc.layout.ndim);
int32_t axis = op_def.axis;
if (axis < 0) {
axis += desc.layout.ndim;
}
mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim);
value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
auto* ptr = value.ptr<dt_int32>();
ptr[0] = desc.layout[op_def.axis];
ptr[0] = desc.layout[axis];
}
return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}


Loading…
Cancel
Save