|
- #include "test/common/accuracy_shake_checker.h"
-
- using namespace megdnn;
- using namespace test;
-
- namespace {
-
- /*!
-  * \brief Compare \p v1 element by element against \p v0, cycling through
-  *      \p v0 as many times as needed to cover all of \p v1.
-  *
-  * Element `pos` of \p v1 is compared with element
-  * `pos % v0.layout.total_nr_elems()` of \p v0. A pair is accepted only when
-  * good_float() holds for both values and their object representations are
-  * identical under memcmp(); the first rejected pair produces the failure.
-  *
-  * Both tensors are addressed as flat arrays of \p ctype, so the caller must
-  * make sure that this is valid for their layouts and dtype.
-  *
-  * \tparam ctype element type used to read both tensors
-  * \param expr0 source text of the expected-value expression (for messages)
-  * \param expr1 source text of the actual-value expression (for messages)
-  * \param v0 reference tensor whose content is reused cyclically
-  * \param v1 tensor under test
-  * \param algo_name algorithm name appended to the failure message
-  * \return AssertionSuccess() if every pair is accepted, otherwise an
-  *      AssertionFailure() describing the first offending element
-  */
- template <typename ctype>
- ::testing::AssertionResult assert_tensor_binary_eq(
-         const char* expr0, const char* expr1, const char* /*expr2*/, const TensorND& v0,
-         const TensorND& v1, const std::string& algo_name) {
-     ctype* const expected_base = v0.ptr<ctype>();
-     ctype* const actual_base = v1.ptr<ctype>();
-     auto total = v1.layout.total_nr_elems();
-     auto period = v0.layout.total_nr_elems();
-
-     for (size_t pos = 0; pos < total; ++pos) {
-         // Wrap around so that v0 is replayed once per period-sized slice of v1.
-         ctype* expected_ptr = expected_base + pos % period;
-         ctype* actual_ptr = actual_base + pos;
-         ctype expected_val = *expected_ptr, actual_val = *actual_ptr;
-
-         // Compare the stored bytes rather than the values themselves.
-         const bool accepted = good_float(expected_val) && good_float(actual_val) &&
-                               !memcmp(expected_ptr, actual_ptr, sizeof(ctype));
-         if (accepted) {
-             continue;
-         }
-
-         Index index(v1.layout, pos);
-         // "+ 0" promotes narrow integer types so they are streamed as numbers.
-         return ::testing::AssertionFailure()
-                << "Unequal value\n"
-                << "Value of: " << expr1 << "\n"
-                << " Actual: " << (actual_val + 0) << "\n"
-                << "Expected: " << expr0 << "\n"
-                << "Which is: " << (expected_val + 0) << "\n"
-                << "At index: " << index.to_string() << "/"
-                << v1.layout.TensorShape::to_string() << "\n"
-                << " DType: " << v1.layout.dtype.name() << "\n"
-                << "algo: " << algo_name;
-     }
-
-     return ::testing::AssertionSuccess();
- }
- }  // namespace
-
- /*!
-  * \brief Check that \p v1 equals \p v0 repeated along dimension 0, bit for bit.
-  *
-  * Validation happens in this order, and the first failing step is reported:
-  *  1. shapes: both tensors have the same number of dimensions, dimension 0
-  *     of \p v0 is 1, and every other dimension of \p v0 equals the matching
-  *     dimension of \p v1;
-  *  2. both layouts are physically contiguous;
-  *  3. both tensors have the same dtype;
-  *  4. element-wise binary comparison, dispatched on the dtype.
-  *
-  * \param expr0 source text of the expected-value expression (for messages)
-  * \param expr1 source text of the actual-value expression (for messages)
-  * \param expr2 source text of the third macro argument; forwarded unused
-  * \param v0 reference tensor with extent 1 in dimension 0
-  * \param v1 tensor under test
-  * \param algo algorithm descriptor; its name is appended to failure messages
-  * \return AssertionSuccess() when all checks pass, otherwise an
-  *      AssertionFailure() describing the first failed check
-  *
-  * Traps via megdnn_trap() if the dtype is covered by neither
-  * MEGDNN_FOREACH_COMPUTING_DTYPE nor MEGDNN_FOREACH_QUANTIZED_DTYPE.
-  */
- ::testing::AssertionResult test::__assert_tensor_binary_eq(
-         const char* expr0, const char* expr1, const char* expr2, const TensorND& v0,
-         const TensorND& v1, const Algorithm::Info::Desc& algo) {
-     // The ranks must agree before any per-dimension comparison is meaningful;
-     // otherwise v1 could carry extra trailing dimensions that are never
-     // inspected, or be indexed past its own rank.
-     bool shape_match = v0.layout.ndim == v1.layout.ndim && v0.layout[0] == 1;
-     for (size_t i = 1; shape_match && i < v0.layout.ndim; ++i) {
-         shape_match = v0.layout[i] == v1.layout[i];
-     }
-     if (!shape_match) {
-         return ::testing::AssertionFailure()
-                << "Shape mismatch\n"
-                << "Value of: " << expr1 << "\n"
-                << " Actual: " << v1.layout.TensorShape::to_string() << "\n"
-                << "Expected: " << expr0 << "\n"
-                << "Which is: " << v0.layout.TensorShape::to_string() << "\n"
-                << "algo: " << algo.name << "\n";
-     }
-
-     // The element-wise comparison walks both tensors as flat arrays.
-     if (!v0.layout.is_physical_contiguous() || !v1.layout.is_physical_contiguous()) {
-         return ::testing::AssertionFailure()
-                << "layout should be physical contiguous\n"
-                << "Value of: " << expr1 << "\n"
-                << " Actual: " << v1.layout.is_physical_contiguous() << "\n"
-                << "Expected: " << expr0 << "\n"
-                << "Which is: " << v0.layout.is_physical_contiguous() << "\n"
-                << "algo: " << algo.name << "\n";
-     }
-     auto dtype = v0.layout.dtype;
-     if (dtype != v1.layout.dtype) {
-         return ::testing::AssertionFailure()
-                << "Data type should match\n"
-                << "Value of: " << expr1 << "\n"
-                << " Actual: " << v1.layout.dtype.name() << "\n"
-                << "Expected: " << expr0 << "\n"
-                << "Which is: " << v0.layout.dtype.name() << "\n"
-                << "algo: " << algo.name << "\n";
-     }
-
-     // Instantiate the comparison for the C type that backs the runtime dtype.
-     switch (dtype.enumv()) {
- #define cb(_dt) \
-     case DTypeTrait<_dt>::enumv: \
-         return assert_tensor_binary_eq<DTypeTrait<_dt>::ctype>( \
-                 expr0, expr1, expr2, v0, v1, algo.name);
-         MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
-         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
- #undef cb
-         default:
-             megdnn_trap();
-     }
- }
-
- // vim: syntax=cpp.doxygen
|