|
|
@@ -40,10 +40,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor( |
|
|
|
ptr[i] = shp.shape[i]; |
|
|
|
} |
|
|
|
}else{ |
|
|
|
mgb_assert(op_def.axis < shp.ndim); |
|
|
|
int32_t axis = op_def.axis; |
|
|
|
if (axis < 0) { |
|
|
|
axis += shp.ndim; |
|
|
|
} |
|
|
|
mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim); |
|
|
|
hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32()); |
|
|
|
auto* ptr = hv.ptr<dt_int32>(); |
|
|
|
ptr[0] = shp.shape[op_def.axis]; |
|
|
|
ptr[0] = shp.shape[axis]; |
|
|
|
} |
|
|
|
return {Tensor::make(std::move(hv))}; |
|
|
|
} |
|
|
@@ -65,10 +69,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
ptr[i] = desc.layout[i]; |
|
|
|
} |
|
|
|
}else{ |
|
|
|
mgb_assert(op_def.axis < desc.layout.ndim); |
|
|
|
int32_t axis = op_def.axis; |
|
|
|
if (axis < 0) { |
|
|
|
axis += desc.layout.ndim; |
|
|
|
} |
|
|
|
mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim); |
|
|
|
value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32()); |
|
|
|
auto* ptr = value.ptr<dt_int32>(); |
|
|
|
ptr[0] = desc.layout[op_def.axis]; |
|
|
|
ptr[0] = desc.layout[axis]; |
|
|
|
} |
|
|
|
return {{{value.layout(), desc.comp_node, std::move(value)}}, true}; |
|
|
|
} |
|
|
|