You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

named_tensor.cpp 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. /**
  2. * \file dnn/test/common/named_tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/named_tensor.h"
  13. // clang-format off
  14. #include "test/common/utils.h"
  15. #include "test/common/fix_gtest_on_platforms_without_exception.inl"
  16. // clang-format on
  17. using namespace megdnn;
  18. using megdnn::test::MegDNNError;
  19. TEST(NAMED_TENSOR, NAMED_TENSOR_SHAPE_BASIC) {
  20. ASSERT_EQ(Dimension::NR_NAMES, 10);
  21. Dimension dim0 = {"C"}, dim1 = {"C//32"}, dim2 = {"C//4"}, dim3 = {"C%32"},
  22. dim4 = {"C%4"}, dim5 = {"C//4%8"};
  23. ASSERT_TRUE(dim0 == dim1 * dim3);
  24. ASSERT_TRUE(dim2 == dim1 * dim5);
  25. ASSERT_THROW(dim0 * dim0, MegDNNError);
  26. ASSERT_TRUE(dim1 == dim0 / dim3);
  27. ASSERT_TRUE(dim3 == dim0 / dim1);
  28. ASSERT_TRUE(dim4 == dim3 / dim5);
  29. ASSERT_TRUE(dim5 == dim3 / dim4);
  30. ASSERT_TRUE(dim5 == dim2 / dim1);
  31. ASSERT_THROW(dim5 / dim1, MegDNNError);
  32. ASSERT_TRUE(dim1 < dim4);
  33. ASSERT_TRUE(dim5 < dim4);
  34. ASSERT_FALSE(dim4 < dim5);
  35. ASSERT_TRUE(dim1 < dim2);
  36. ASSERT_FALSE(dim2 < dim1);
  37. auto shape0 =
  38. NamedTensorShape::make_named_tensor_shape(NamedTensorShape::Format::NCHW);
  39. SmallVector<Dimension> dims = {{"N"}, {"C"}, {"H"}, {"W"}};
  40. NamedTensorShape shape1(dims);
  41. NamedTensorShape shape2{{"N"}, {"C"}, {"H"}, {"W"}};
  42. ASSERT_TRUE(shape0.eq_shape(shape1));
  43. ASSERT_TRUE(shape0.eq_shape(shape2));
  44. ASSERT_TRUE(shape1.eq_shape(shape2));
  45. auto shape3 =
  46. NamedTensorShape::make_named_tensor_shape(NamedTensorShape::Format::NCHW4);
  47. ASSERT_FALSE(shape0.eq_shape(shape3));
  48. auto shape4 = NamedTensorShape::make_named_tensor_shape(
  49. NamedTensorShape::Format::NCHW44_DOT);
  50. std::sort(shape4.dims.begin(), shape4.dims.begin() + shape4.ndim);
  51. NamedTensorShape shape5{{"N"}, {"C//32"}, {"H"}, {"W"}, {"C//8%4"}, {"C%8"}};
  52. std::sort(shape5.dims.begin(), shape5.dims.begin() + shape5.ndim);
  53. }
  54. // vim: syntax=cpp.doxygen