|
|
@@ -23,8 +23,7 @@ namespace mgb::imperative { |
|
|
|
* |
|
|
|
*/ |
|
|
|
struct ValueShape { |
|
|
|
size_t shape[TensorShape::MAX_NDIM]; |
|
|
|
int ndim = 0; |
|
|
|
size_t shape[TensorShape::MAX_NDIM], ndim = 0; |
|
|
|
|
|
|
|
ValueShape() = default; |
|
|
|
ValueShape(std::initializer_list<size_t> dims) { |
|
|
@@ -70,19 +69,14 @@ struct ValueShape { |
|
|
|
return buffer; |
|
|
|
} |
|
|
|
|
|
|
|
static ValueShape from(TensorShape tensor_shape) { |
|
|
|
static const ValueShape& from(const TensorShape& tensor_shape) { |
|
|
|
mgb_assert(tensor_shape.ndim); |
|
|
|
return Span<size_t>{tensor_shape.shape, tensor_shape.ndim}; |
|
|
|
return reinterpret_cast<const ValueShape&>(tensor_shape); |
|
|
|
} |
|
|
|
|
|
|
|
TensorShape as_tensor_shape() const { |
|
|
|
const TensorShape& as_tensor_shape() const { |
|
|
|
mgb_assert(ndim != 0); |
|
|
|
TensorShape ret; |
|
|
|
for (size_t i = 0; i < ndim; ++i) { |
|
|
|
ret.shape[i] = shape[i]; |
|
|
|
} |
|
|
|
ret.ndim = ndim; |
|
|
|
return ret; |
|
|
|
return reinterpret_cast<const TensorShape&>(*this); |
|
|
|
} |
|
|
|
|
|
|
|
bool operator==(const ValueShape& rhs) const { |
|
|
|