You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

accuracy_shake_checker.cpp 3.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. #include "test/common/accuracy_shake_checker.h"
  2. using namespace megdnn;
  3. using namespace test;
  4. namespace {
  5. template <typename ctype>
  6. ::testing::AssertionResult assert_tensor_binary_eq(
  7. const char* expr0, const char* expr1, const char* /*expr2*/, const TensorND& v0,
  8. const TensorND& v1, const std::string& algo_name) {
  9. ctype* it0_orig = v0.ptr<ctype>();
  10. ctype* it1 = v1.ptr<ctype>();
  11. ctype* it0 = it0_orig;
  12. auto nr_elem = v1.layout.total_nr_elems();
  13. auto nr_elem_single_batch = v0.layout.total_nr_elems();
  14. for (size_t i = 0; i < nr_elem; ++i) {
  15. if (i % nr_elem_single_batch == 0) {
  16. it0 = it0_orig;
  17. }
  18. ctype iv0 = *it0, iv1 = *it1;
  19. if (!good_float(iv0) || !good_float(iv1) || memcmp(it0, it1, sizeof(ctype))) {
  20. Index index(v1.layout, i);
  21. return ::testing::AssertionFailure()
  22. << "Unequal value\n"
  23. << "Value of: " << expr1 << "\n"
  24. << " Actual: " << (iv1 + 0) << "\n"
  25. << "Expected: " << expr0 << "\n"
  26. << "Which is: " << (iv0 + 0) << "\n"
  27. << "At index: " << index.to_string() << "/"
  28. << v1.layout.TensorShape::to_string() << "\n"
  29. << " DType: " << v1.layout.dtype.name() << "\n"
  30. << "algo: " << algo_name;
  31. }
  32. ++it0;
  33. ++it1;
  34. }
  35. return ::testing::AssertionSuccess();
  36. }
  37. } // namespace
  38. ::testing::AssertionResult test::__assert_tensor_binary_eq(
  39. const char* expr0, const char* expr1, const char* expr2, const TensorND& v0,
  40. const TensorND& v1, const Algorithm::Info::Desc& algo) {
  41. bool shape_match = v0.layout[0] == 1;
  42. for (size_t i = 1; i < v0.layout.ndim; ++i) {
  43. shape_match &= v0.layout[i] == v1.layout[i];
  44. }
  45. if (!shape_match) {
  46. return ::testing::AssertionFailure()
  47. << "Shape mismatch\n"
  48. << "Value of: " << expr1 << "\n"
  49. << " Actual: " << v1.layout.TensorShape::to_string() << "\n"
  50. << "Expected: " << expr0 << "\n"
  51. << "Which is: " << v0.layout.TensorShape::to_string() << "\n"
  52. << "algo: " << algo.name << "\n";
  53. }
  54. if (!v0.layout.is_physical_contiguous() || !v1.layout.is_physical_contiguous()) {
  55. return ::testing::AssertionFailure()
  56. << "layout should be physical contiguous\n"
  57. << "Value of: " << expr1 << "\n"
  58. << " Actual: " << v1.layout.is_physical_contiguous() << "\n"
  59. << "Expected: " << expr0 << "\n"
  60. << "Which is: " << v0.layout.is_physical_contiguous() << "\n"
  61. << "algo: " << algo.name << "\n";
  62. }
  63. auto dtype = v0.layout.dtype;
  64. if (dtype != v1.layout.dtype) {
  65. return ::testing::AssertionFailure()
  66. << "Data type should match\n"
  67. << "Value of: " << expr1 << "\n"
  68. << " Actual: " << v1.layout.dtype.name() << "\n"
  69. << "Expected: " << expr0 << "\n"
  70. << "Which is: " << v0.layout.dtype.name() << "\n"
  71. << "algo: " << algo.name << "\n";
  72. }
  73. switch (dtype.enumv()) {
  74. #define cb(_dt) \
  75. case DTypeTrait<_dt>::enumv: \
  76. return assert_tensor_binary_eq<DTypeTrait<_dt>::ctype>( \
  77. expr0, expr1, expr2, v0, v1, algo.name);
  78. MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
  79. MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
  80. #undef cb
  81. default:
  82. megdnn_trap();
  83. }
  84. }
  85. // vim: syntax=cpp.doxygen