You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930
  1. #include "test/cuda/fixture.h"
  2. #include "megdnn/oprs/linalg.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/random_state.h"
  5. #include "test/common/svd.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. TEST_F(CUDA, SINGULAR_VALUE_DECOMPOSITION) {
  9. auto opr_naive = handle_naive()->create_operator<SVDForward>();
  10. auto opr_cuda = handle_cuda()->create_operator<SVDForward>();
  11. auto testcases = SVDTestcase::make();
  12. for (auto& t : testcases) {
  13. auto cuda_result = t.run(opr_cuda.get());
  14. bool old_compute_nv = t.m_param.compute_uv;
  15. t.m_param.compute_uv = false;
  16. auto naive_result = t.run(opr_naive.get());
  17. t.m_param.compute_uv = old_compute_nv;
  18. MEGDNN_ASSERT_TENSOR_EQ(*naive_result.s, *cuda_result.s);
  19. if (t.m_param.compute_uv) {
  20. MEGDNN_ASSERT_TENSOR_EQ(*cuda_result.recovered_mat, t.m_mat);
  21. }
  22. }
  23. }
  24. // vim: syntax=cpp.doxygen