|
|
@@ -21,9 +21,13 @@ |
|
|
|
|
|
|
|
namespace mgb::imperative { |
|
|
|
|
|
|
|
// Implement `ToStringTrait` for your printable class |
|
|
|
// note that it should be either implemented in this file |
|
|
|
// or in the same file with your class |
|
|
|
template <typename T> |
|
|
|
struct ToStringTrait; |
|
|
|
|
|
|
|
// Call `to_string` to print your value |
|
|
|
template <typename T> |
|
|
|
std::string to_string(const T& value) { |
|
|
|
return ToStringTrait<T>{}(value); |
|
|
@@ -92,7 +96,6 @@ template <> |
|
|
|
struct ToStringTrait<TensorShape>{ |
|
|
|
std::string operator()(TensorShape shape) const { |
|
|
|
if (shape.ndim > TensorShape::MAX_NDIM) { |
|
|
|
printf("ndim: %d\n", (int)shape.ndim); |
|
|
|
return "[]"; |
|
|
|
} |
|
|
|
mgb_assert(shape.ndim <= TensorShape::MAX_NDIM); |
|
|
|