You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

named_tensor.cpp 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. #include "megdnn/named_tensor.h"
  2. // clang-format off
  3. #include "test/common/utils.h"
  4. #include "test/common/fix_gtest_on_platforms_without_exception.inl"
  5. // clang-format on
  6. using namespace megdnn;
  7. using megdnn::test::MegDNNError;
  8. TEST(NAMED_TENSOR, NAMED_TENSOR_SHAPE_BASIC) {
  9. ASSERT_EQ(Dimension::NR_NAMES, 10);
  10. Dimension dim0 = {"C"}, dim1 = {"C//32"}, dim2 = {"C//4"}, dim3 = {"C%32"},
  11. dim4 = {"C%4"}, dim5 = {"C//4%8"};
  12. ASSERT_TRUE(dim0 == dim1 * dim3);
  13. ASSERT_TRUE(dim2 == dim1 * dim5);
  14. ASSERT_THROW(dim0 * dim0, MegDNNError);
  15. ASSERT_TRUE(dim1 == dim0 / dim3);
  16. ASSERT_TRUE(dim3 == dim0 / dim1);
  17. ASSERT_TRUE(dim4 == dim3 / dim5);
  18. ASSERT_TRUE(dim5 == dim3 / dim4);
  19. ASSERT_TRUE(dim5 == dim2 / dim1);
  20. ASSERT_THROW(dim5 / dim1, MegDNNError);
  21. ASSERT_TRUE(dim1 < dim4);
  22. ASSERT_TRUE(dim5 < dim4);
  23. ASSERT_FALSE(dim4 < dim5);
  24. ASSERT_TRUE(dim1 < dim2);
  25. ASSERT_FALSE(dim2 < dim1);
  26. auto shape0 =
  27. NamedTensorShape::make_named_tensor_shape(NamedTensorShape::Format::NCHW);
  28. SmallVector<Dimension> dims = {{"N"}, {"C"}, {"H"}, {"W"}};
  29. NamedTensorShape shape1(dims);
  30. NamedTensorShape shape2{{"N"}, {"C"}, {"H"}, {"W"}};
  31. ASSERT_TRUE(shape0.eq_shape(shape1));
  32. ASSERT_TRUE(shape0.eq_shape(shape2));
  33. ASSERT_TRUE(shape1.eq_shape(shape2));
  34. auto shape3 =
  35. NamedTensorShape::make_named_tensor_shape(NamedTensorShape::Format::NCHW4);
  36. ASSERT_FALSE(shape0.eq_shape(shape3));
  37. auto shape4 = NamedTensorShape::make_named_tensor_shape(
  38. NamedTensorShape::Format::NCHW44_DOT);
  39. std::sort(shape4.dims.begin(), shape4.dims.begin() + shape4.ndim);
  40. NamedTensorShape shape5{{"N"}, {"C//32"}, {"H"}, {"W"}, {"C//8%4"}, {"C%8"}};
  41. std::sort(shape5.dims.begin(), shape5.dims.begin() + shape5.ndim);
  42. }
  43. // vim: syntax=cpp.doxygen